Skip to content

Commit 3ae6ee4

Browse files
jerryzh168 and facebook-github-bot
authored and committed
Move CreateContext to global registry (pytorch#11688)
Summary: Pull Request resolved: pytorch#11688 As a first step toward removing the static context (merging it with the allocator), we create a global registry for context constructors and remove the CreateContext function from the tensor. Reviewed By: ezyang, dzhulgakov Differential Revision: D9779821 fbshipit-source-id: 8b239ea50af7a0556fde2382f58f79194f0e3dc1
1 parent b7c302d commit 3ae6ee4

21 files changed

+122
-94
lines changed

aten/src/ATen/core/context_base.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
#include <ATen/core/context_base.h>
22

3+
namespace at {
4+
5+
AT_DEFINE_TYPED_REGISTRY(
6+
ContextRegistry,
7+
DeviceType,
8+
BaseContext,
9+
std::unique_ptr,
10+
at::Device);
11+
12+
} // namespace at
13+
314
namespace caffe2 {
415

516
// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h

aten/src/ATen/core/context_base.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
#include <memory>
77
#include <unordered_map>
88

9-
#include <ATen/core/DeviceType.h>
9+
#include <ATen/core/ATenGeneral.h>
10+
#include <ATen/core/Device.h>
1011
#include <ATen/core/Error.h>
12+
#include <ATen/core/Registry.h>
1113
#include <ATen/core/UniqueVoidPtr.h>
1214
#include <ATen/core/typeid.h>
13-
#include <ATen/core/ATenGeneral.h>
1415

1516
namespace caffe2 {
1617
class Event;
@@ -31,11 +32,6 @@ class AT_CORE_API BaseStaticContext {
3132

3233
virtual std::pair<void*, DeleterFnPtr> New(size_t nbytes) const = 0;
3334

34-
virtual std::unique_ptr<BaseContext> CreateContext() = 0;
35-
36-
virtual std::unique_ptr<BaseContext> CreateContext(
37-
const caffe2::DeviceOption&) = 0;
38-
3935
virtual DeviceType GetDeviceType() = 0;
4036

4137
/*
@@ -184,6 +180,22 @@ class AT_CORE_API BaseContext {
184180
}
185181
};
186182

183+
// Context constructor registry
184+
AT_DECLARE_TYPED_REGISTRY(
185+
ContextRegistry,
186+
at::DeviceType,
187+
BaseContext,
188+
std::unique_ptr,
189+
at::Device);
190+
191+
#define REGISTER_CONTEXT(type, ...) \
192+
AT_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)
193+
194+
inline std::unique_ptr<at::BaseContext> CreateContext(
195+
const at::Device& device) {
196+
return ContextRegistry()->Create(device.type(), device);
197+
}
198+
187199
} // namespace at
188200

189201
namespace caffe2 {

caffe2/core/blob_serialization.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ void TensorSerializer::Serialize(
196196
const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
197197
proto.set_data_type(data_type);
198198
StoreDeviceDetail(input, &proto);
199-
auto uniq_ptr = input.GetStaticContext()->CreateContext();
199+
auto uniq_ptr = CreateContext(input.GetDevice());
200200
// A lot of copypaste is error prone. Should we create a macro for this?
201201
switch (data_type) {
202202
case TensorProto_DataType_FLOAT:
@@ -370,8 +370,7 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
370370
void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
371371
// We create a local context for deserializing. Since Caffe2 contexts are
372372
// usually lightweight, this should not involve too much overhead.
373-
auto uniq_ptr =
374-
tensor->GetStaticContext()->CreateContext(proto.device_detail());
373+
auto uniq_ptr = CreateContext(OptionToDevice(proto.device_detail()));
375374
auto context = uniq_ptr.get();
376375
context->SwitchToDevice(0);
377376
vector<int64_t> dims;

caffe2/core/context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#include <process.h>
66
#endif
77

8+
namespace at {
9+
10+
REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
11+
} // namespace at
812
namespace caffe2 {
913

1014
uint32_t RandomNumberSeed() {

caffe2/core/context.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class CAFFE2_API CPUContext final : public BaseContext {
5050
: RandomNumberSeed()) {
5151
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU);
5252
}
53+
explicit CPUContext(const at::Device& device)
54+
: CPUContext(DeviceToOption(device)) {}
5355

5456
~CPUContext() noexcept override {}
5557

@@ -192,15 +194,6 @@ class CAFFE2_API CPUStaticContext : public BaseStaticContext {
192194
return data_and_deleter;
193195
}
194196

195-
std::unique_ptr<BaseContext> CreateContext() override {
196-
return caffe2::make_unique<CPUContext>();
197-
}
198-
199-
std::unique_ptr<BaseContext> CreateContext(
200-
const DeviceOption& option) override {
201-
return caffe2::make_unique<CPUContext>(option);
202-
}
203-
204197
DeviceType GetDeviceType() override {
205198
return CPU;
206199
}

caffe2/core/context_base.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "context_base.h"
22

33
namespace caffe2 {
4+
45
} // namespace caffe2

caffe2/core/context_base.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
#include "caffe2/core/common.h"
66
#include "caffe2/core/logging.h"
77
#include "caffe2/proto/caffe2_pb.h"
8+
9+
namespace caffe2 {} // namespace caffe2

caffe2/core/context_gpu.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ CAFFE2_DEFINE_int(
5757
128,
5858
"The threshold in MB on how frequently to report memory changes");
5959

60+
namespace at {
61+
62+
REGISTER_CONTEXT(DeviceType::CUDA, caffe2::CUDAContext);
63+
} // namespace at
64+
6065
namespace caffe2 {
6166

6267
ThreadLocalCUDAObjects& CUDAContext::getCudaObjects() {

caffe2/core/context_gpu.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
142142
// The default cuda context constructor.
143143
explicit CUDAContext(const int gpu_id = -1);
144144
explicit CUDAContext(const DeviceOption& option);
145+
explicit CUDAContext(const at::Device& device)
146+
: CUDAContext(DeviceToOption(device)) {}
145147

146148
~CUDAContext() override {
147149
if (curand_generator_) {
@@ -385,19 +387,6 @@ class CAFFE2_CUDA_API CUDAStaticContext final : public BaseStaticContext {
385387
public:
386388
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;
387389

388-
std::unique_ptr<BaseContext> CreateContext() override {
389-
return caffe2::make_unique<CUDAContext>();
390-
}
391-
392-
std::unique_ptr<BaseContext> CreateContext(
393-
const DeviceOption& option) override {
394-
return caffe2::make_unique<CUDAContext>(option);
395-
}
396-
397-
std::unique_ptr<BaseContext> CreateContext(int gpu_id = -1) {
398-
return caffe2::make_unique<CUDAContext>(gpu_id);
399-
}
400-
401390
DeviceType GetDeviceType() override {
402391
return CUDA;
403392
}

caffe2/core/hip/context_hip.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ CAFFE2_DEFINE_int(caffe2_gpu_memory_report_interval_mb,
5050
128,
5151
"The threshold in MB on how frequently to report memory changes");
5252

53+
namespace at {
54+
55+
REGISTER_CONTEXT(DeviceType::HIP, caffe2::HIPContext);
56+
} // namespace at
57+
5358
namespace caffe2 {
5459

5560
thread_local ThreadLocalHIPObjects HIPContext::hip_objects_;
@@ -408,13 +413,12 @@ void HIPStaticContext::Delete(void* ptr) {
408413
g_hip_device_affiliation.erase(it);
409414
break;
410415
}
411-
case HipMemoryPoolType::THC:
412-
{
413-
HIP_ENFORCE(g_thc_allocator->Free(ptr));
414-
if (FLAGS_caffe2_gpu_memory_tracking) {
415-
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
416-
}
417-
break;
416+
case HipMemoryPoolType::THC: {
417+
HIP_ENFORCE(g_thc_allocator->Free(ptr));
418+
if (FLAGS_caffe2_gpu_memory_tracking) {
419+
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
420+
}
421+
break;
418422
}
419423
}
420424
}

caffe2/core/hip/context_hip.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ class HIPContext final : public BaseContext {
127127
// The default HIP context constructor.
128128
explicit HIPContext(const int gpu_id = -1);
129129
explicit HIPContext(const DeviceOption& option);
130+
explicit HIPContext(const at::Device& device)
131+
: HIPContext(DeviceToOption(device)) {}
130132

131133
~HIPContext() override {
132134
if (hiprand_generator_) {
@@ -374,19 +376,6 @@ class HIPStaticContext final : public BaseStaticContext {
374376
public:
375377
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;
376378

377-
std::unique_ptr<BaseContext> CreateContext() override {
378-
return caffe2::make_unique<HIPContext>();
379-
}
380-
381-
std::unique_ptr<BaseContext> CreateContext(
382-
const DeviceOption& option) override {
383-
return caffe2::make_unique<HIPContext>(option);
384-
}
385-
386-
std::unique_ptr<BaseContext> CreateContext(int gpu_id = -1) {
387-
return caffe2::make_unique<HIPContext>(gpu_id);
388-
}
389-
390379
DeviceType GetDeviceType() override {
391380
return HIP;
392381
}

caffe2/core/registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class Registerer {
172172
key, \
173173
RegistryName(), \
174174
Registerer##RegistryName::DefaultCreator<__VA_ARGS__>, \
175-
at::demangle_type<__VA_ARGS__>()); \
175+
at::demangle_type<__VA_ARGS__>()); \
176176
}
177177

178178
// CAFFE_DECLARE_REGISTRY and CAFFE_DEFINE_REGISTRY are hard-wired to use string

caffe2/core/tensor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,14 @@ class CAFFE2_API Tensor final {
130130
return impl_.get()->GetStaticContext();
131131
}
132132

133-
std::unique_ptr<BaseContext> CreateContext() const {
134-
return impl_.get()->CreateContext();
135-
}
136-
137133
DeviceType GetDeviceType() const {
138134
return impl_.get()->GetDeviceType();
139135
}
140136

137+
at::Device GetDevice() const {
138+
return impl_.get()->GetDevice();
139+
}
140+
141141
void CopyFrom(const Tensor& src, BaseContext* context = nullptr) const {
142142
impl_.get()->CopyFrom(*src.impl_.get(), context);
143143
}

caffe2/core/tensor_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "caffe2/core/tensor_impl.h"
2-
2+
#include "caffe2/core/context_base.h"
33
#include "caffe2/core/flags.h"
44

55
CAFFE2_DEFINE_bool(

caffe2/core/tensor_impl.h

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <ATen/core/DimVector.h>
44
#include <ATen/core/TensorImpl.h>
55
#include <ATen/core/context_base.h>
6-
#include <ATen/core/context_base.h>
76

87
#include "caffe2/core/allocator.h"
98
#include "caffe2/core/common.h"
@@ -112,21 +111,14 @@ class CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
112111
return get_static_context(device_type);
113112
}
114113

115-
/* @brief
116-
* Create a context that has the same device_type
117-
* as the tensor.
118-
* Note that this doesn't support passing in argument
119-
* TODO(jerryzh): move this to a global registry
120-
* that can create context for us
121-
*/
122-
std::unique_ptr<at::BaseContext> CreateContext() const {
123-
return GetStaticContext()->CreateContext();
124-
}
125-
126114
at::DeviceType GetDeviceType() const {
127115
return storage_.device_type();
128116
}
129117

118+
at::Device GetDevice() const {
119+
return storage_.device();
120+
}
121+
130122
/**
131123
* @brief Copies the data from a source tensor, with a context provided to
132124
* carry out the underlying memcpy operation.
@@ -167,8 +159,12 @@ class CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
167159
// knows how to copy between CPU and that context
168160
if (src.GetDeviceType() != ::at::DeviceType::CPU || GetDeviceType() == ::at::DeviceType::CPU) {
169161
if (!context) {
170-
src.CreateContext()->CopyBytesToDevice(
171-
nbytes(), src.raw_data(), raw_mutable_data(), GetDeviceType());
162+
CreateContext(src.GetDevice())
163+
->CopyBytesToDevice(
164+
nbytes(),
165+
src.raw_data(),
166+
raw_mutable_data(),
167+
GetDeviceType());
172168
} else {
173169
CAFFE_ENFORCE(
174170
context->device_type() == src.GetDeviceType(),
@@ -180,8 +176,8 @@ class CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
180176
// In case source context is CPU, and target context is non-CPU
181177
// We'll have to create a Context from target and perform the
182178
// copy using that context
183-
CreateContext()->CopyBytesFromCPU(
184-
nbytes(), src.raw_data(), raw_mutable_data());
179+
CreateContext(GetDevice())
180+
->CopyBytesFromCPU(nbytes(), src.raw_data(), raw_mutable_data());
185181
}
186182
}
187183
}

caffe2/ideep/utils/ideep_context.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class IDEEPContext final : public BaseContext {
2020
: RandomNumberSeed()) {
2121
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_IDEEP);
2222
}
23+
explicit IDEEPContext(const at::Device& device)
24+
: IDEEPContext(DeviceToOption(device)) {}
2325

2426
~IDEEPContext() noexcept override {}
2527

@@ -178,15 +180,6 @@ class IDEEPStaticContext : public BaseStaticContext {
178180
return GetCPUAllocator()->New(nbytes);
179181
}
180182

181-
std::unique_ptr<BaseContext> CreateContext() override {
182-
return caffe2::make_unique<IDEEPContext>();
183-
}
184-
185-
std::unique_ptr<BaseContext> CreateContext(
186-
const DeviceOption& option) override {
187-
return caffe2::make_unique<IDEEPContext>(option);
188-
}
189-
190183
DeviceType GetDeviceType() override {
191184
return IDEEP;
192185
}

caffe2/ideep/utils/ideep_register.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include <ideep_pin_singletons.hpp>
55
#include "ideep_context.h"
66

7+
namespace at {
8+
REGISTER_CONTEXT(DeviceType::IDEEP, caffe2::IDEEPContext);
9+
} // namespace at
710
namespace caffe2 {
811

912
CAFFE_KNOWN_TYPE(ideep::tensor);

caffe2/mkl/utils/mkl_context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include "mkl_context.h"
44
#include "caffe2/core/event_cpu.h"
55

6+
namespace at {
7+
8+
REGISTER_CONTEXT(DeviceType::MKLDNN, caffe2::MKLContext);
9+
} // namespace at
610
namespace caffe2 {
711

812
// MKL events are the same as CPU events

0 commit comments

Comments
 (0)