Skip to content

Commit 3ae6ee4

Browse files
jerryzh168facebook-github-bot
authored andcommitted
Move CreateContext to global registry (pytorch#11688)
Summary: Pull Request resolved: pytorch#11688 As a first step to remove static context(merge with allocator), we'll create a global registries for context constructors, and remove CreateContext function from tensor. Reviewed By: ezyang, dzhulgakov Differential Revision: D9779821 fbshipit-source-id: 8b239ea50af7a0556fde2382f58f79194f0e3dc1
1 parent b7c302d commit 3ae6ee4

21 files changed

+122
-94
lines changed

aten/src/ATen/core/context_base.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
#include <ATen/core/context_base.h>
22

3+
namespace at {
4+
5+
AT_DEFINE_TYPED_REGISTRY(
6+
ContextRegistry,
7+
DeviceType,
8+
BaseContext,
9+
std::unique_ptr,
10+
at::Device);
11+
12+
} // namespace at
13+
314
namespace caffe2 {
415

516
// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h

aten/src/ATen/core/context_base.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
#include <memory>
77
#include <unordered_map>
88

9-
#include <ATen/core/DeviceType.h>
9+
#include <ATen/core/ATenGeneral.h>
10+
#include <ATen/core/Device.h>
1011
#include <ATen/core/Error.h>
12+
#include <ATen/core/Registry.h>
1113
#include <ATen/core/UniqueVoidPtr.h>
1214
#include <ATen/core/typeid.h>
13-
#include <ATen/core/ATenGeneral.h>
1415

1516
namespace caffe2 {
1617
class Event;
@@ -31,11 +32,6 @@ class AT_CORE_API BaseStaticContext {
3132

3233
virtual std::pair<void*, DeleterFnPtr> New(size_t nbytes) const = 0;
3334

34-
virtual std::unique_ptr<BaseContext> CreateContext() = 0;
35-
36-
virtual std::unique_ptr<BaseContext> CreateContext(
37-
const caffe2::DeviceOption&) = 0;
38-
3935
virtual DeviceType GetDeviceType() = 0;
4036

4137
/*
@@ -184,6 +180,22 @@ class AT_CORE_API BaseContext {
184180
}
185181
};
186182

183+
// Context constructor registry
184+
AT_DECLARE_TYPED_REGISTRY(
185+
ContextRegistry,
186+
at::DeviceType,
187+
BaseContext,
188+
std::unique_ptr,
189+
at::Device);
190+
191+
#define REGISTER_CONTEXT(type, ...) \
192+
AT_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)
193+
194+
inline std::unique_ptr<at::BaseContext> CreateContext(
195+
const at::Device& device) {
196+
return ContextRegistry()->Create(device.type(), device);
197+
}
198+
187199
} // namespace at
188200

189201
namespace caffe2 {

caffe2/core/blob_serialization.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ void TensorSerializer::Serialize(
196196
const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
197197
proto.set_data_type(data_type);
198198
StoreDeviceDetail(input, &proto);
199-
auto uniq_ptr = input.GetStaticContext()->CreateContext();
199+
auto uniq_ptr = CreateContext(input.GetDevice());
200200
// A lot of copypaste is error prone. Should we create a macro for this?
201201
switch (data_type) {
202202
case TensorProto_DataType_FLOAT:
@@ -370,8 +370,7 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
370370
void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
371371
// We create a local context for deserializing. Since Caffe2 contexts are
372372
// usually lightweight, this should not involve too much overhead.
373-
auto uniq_ptr =
374-
tensor->GetStaticContext()->CreateContext(proto.device_detail());
373+
auto uniq_ptr = CreateContext(OptionToDevice(proto.device_detail()));
375374
auto context = uniq_ptr.get();
376375
context->SwitchToDevice(0);
377376
vector<int64_t> dims;

caffe2/core/context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#include <process.h>
66
#endif
77

8+
namespace at {
9+
10+
REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
11+
} // namespace at
812
namespace caffe2 {
913

1014
uint32_t RandomNumberSeed() {

caffe2/core/context.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class CAFFE2_API CPUContext final : public BaseContext {
5050
: RandomNumberSeed()) {
5151
CAFFE_ENFORCE_EQ(option.device_type(), PROTO_CPU);
5252
}
53+
explicit CPUContext(const at::Device& device)
54+
: CPUContext(DeviceToOption(device)) {}
5355

5456
~CPUContext() noexcept override {}
5557

@@ -192,15 +194,6 @@ class CAFFE2_API CPUStaticContext : public BaseStaticContext {
192194
return data_and_deleter;
193195
}
194196

195-
std::unique_ptr<BaseContext> CreateContext() override {
196-
return caffe2::make_unique<CPUContext>();
197-
}
198-
199-
std::unique_ptr<BaseContext> CreateContext(
200-
const DeviceOption& option) override {
201-
return caffe2::make_unique<CPUContext>(option);
202-
}
203-
204197
DeviceType GetDeviceType() override {
205198
return CPU;
206199
}

caffe2/core/context_base.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "context_base.h"
22

33
namespace caffe2 {
4+
45
} // namespace caffe2

caffe2/core/context_base.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
#include "caffe2/core/common.h"
66
#include "caffe2/core/logging.h"
77
#include "caffe2/proto/caffe2_pb.h"
8+
9+
namespace caffe2 {} // namespace caffe2

caffe2/core/context_gpu.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ CAFFE2_DEFINE_int(
5757
128,
5858
"The threshold in MB on how frequently to report memory changes");
5959

60+
namespace at {
61+
62+
REGISTER_CONTEXT(DeviceType::CUDA, caffe2::CUDAContext);
63+
} // namespace at
64+
6065
namespace caffe2 {
6166

6267
ThreadLocalCUDAObjects& CUDAContext::getCudaObjects() {

caffe2/core/context_gpu.h

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
142142
// The default cuda context constructor.
143143
explicit CUDAContext(const int gpu_id = -1);
144144
explicit CUDAContext(const DeviceOption& option);
145+
explicit CUDAContext(const at::Device& device)
146+
: CUDAContext(DeviceToOption(device)) {}
145147

146148
~CUDAContext() override {
147149
if (curand_generator_) {
@@ -385,19 +387,6 @@ class CAFFE2_CUDA_API CUDAStaticContext final : public BaseStaticContext {
385387
public:
386388
std::pair<void*, MemoryDeleter> New(size_t nbytes) const override;
387389

388-
std::unique_ptr<BaseContext> CreateContext() override {
389-
return caffe2::make_unique<CUDAContext>();
390-
}
391-
392-
std::unique_ptr<BaseContext> CreateContext(
393-
const DeviceOption& option) override {
394-
return caffe2::make_unique<CUDAContext>(option);
395-
}
396-
397-
std::unique_ptr<BaseContext> CreateContext(int gpu_id = -1) {
398-
return caffe2::make_unique<CUDAContext>(gpu_id);
399-
}
400-
401390
DeviceType GetDeviceType() override {
402391
return CUDA;
403392
}

caffe2/core/hip/context_hip.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ CAFFE2_DEFINE_int(caffe2_gpu_memory_report_interval_mb,
5050
128,
5151
"The threshold in MB on how frequently to report memory changes");
5252

53+
namespace at {
54+
55+
REGISTER_CONTEXT(DeviceType::HIP, caffe2::HIPContext);
56+
} // namespace at
57+
5358
namespace caffe2 {
5459

5560
thread_local ThreadLocalHIPObjects HIPContext::hip_objects_;
@@ -408,13 +413,12 @@ void HIPStaticContext::Delete(void* ptr) {
408413
g_hip_device_affiliation.erase(it);
409414
break;
410415
}
411-
case HipMemoryPoolType::THC:
412-
{
413-
HIP_ENFORCE(g_thc_allocator->Free(ptr));
414-
if (FLAGS_caffe2_gpu_memory_tracking) {
415-
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
416-
}
417-
break;
416+
case HipMemoryPoolType::THC: {
417+
HIP_ENFORCE(g_thc_allocator->Free(ptr));
418+
if (FLAGS_caffe2_gpu_memory_tracking) {
419+
g_hip_device_affiliation.erase(g_hip_device_affiliation.find(ptr));
420+
}
421+
break;
418422
}
419423
}
420424
}

0 commit comments

Comments
 (0)