Skip to content

Commit e1240dc

Browse files
ezyangPenghuiCheng
authored andcommitted
Delete context and get_context from Type.
Summary: Pull Request resolved: pytorch#11001 Reviewed By: cpuhrsch Differential Revision: D9557315 fbshipit-source-id: b9862b8dda49194298bb1a4fbc214d466f3c8350
1 parent 2a79b56 commit e1240dc

File tree

11 files changed

+18
-21
lines changed

11 files changed

+18
-21
lines changed

aten/src/ATen/UndefinedType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
namespace at {
55

6-
UndefinedType::UndefinedType(Context* context)
7-
: Type(context, UndefinedTensorId(), /*is_variable=*/false, /*is_undefined=*/true) {}
6+
UndefinedType::UndefinedType()
7+
: Type(UndefinedTensorId(), /*is_variable=*/false, /*is_undefined=*/true) {}
88
ScalarType UndefinedType::scalarType() const {
99
return ScalarType::Undefined;
1010
}

aten/src/ATen/UndefinedType.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
namespace at {
1313

1414
struct UndefinedType final : public Type {
15-
explicit UndefinedType(Context* context);
15+
explicit UndefinedType();
1616
virtual ScalarType scalarType() const override;
1717
virtual Backend backend() const override;
1818
virtual bool is_cuda() const override;

aten/src/ATen/gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def check_all_files_written(self):
125125
TYPE_REGISTER = CodeTemplate("""\
126126
context->type_registry[static_cast<int>(Backend::${backend})]
127127
[static_cast<int>(ScalarType::${scalar_type})]
128-
.reset(new ${type_name}(context));
128+
.reset(new ${type_name}());
129129
detail::getVariableHooks().registerVariableTypeFor(context, Backend::${backend}, ScalarType::${scalar_type});
130130
""")
131131

aten/src/ATen/native/cuda/Gesv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void magmaGesvBatched<double>(
4848
}
4949

5050
static magma_queue_t createMagmaQueue(const Tensor& tensor) {
51-
auto& context = tensor.type().get_context();
51+
auto& context = at::globalContext();
5252
magma_queue_t magma_queue;
5353
magma_queue_create_from_cuda(
5454
tensor.get_device(),

aten/src/ATen/templates/RegisterCPU.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace at {
1414
void register_cpu_types(Context * context) {
1515
${cpu_type_registrations}
1616
context->type_registry[static_cast<int>(Backend::Undefined)]
17-
[static_cast<int>(ScalarType::Undefined)].reset(new UndefinedType(context));
17+
[static_cast<int>(ScalarType::Undefined)].reset(new UndefinedType());
1818
}
1919

2020
} // namespace at

aten/src/ATen/templates/SparseTypeDerived.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
namespace at {
2929

30-
${Type}::${Type}(Context* context)
31-
: Type(context, ${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
30+
${Type}::${Type}()
31+
: Type(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
3232
ScalarType ${Type}::scalarType() const {
3333
return ScalarType::${ScalarName};
3434
}
@@ -58,7 +58,7 @@ Storage ${Type}::unsafeStorageFromTH(void * th_pointer, bool retain) const {
5858
AT_ERROR("unsafeTensorFromTH not supported on sparse");
5959
}
6060
std::unique_ptr<Generator> ${Type}::generator() const {
61-
return std::unique_ptr<Generator>(new ${Generator}(context));
61+
return std::unique_ptr<Generator>(new ${Generator}(&at::globalContext()));
6262
}
6363

6464
const char * ${Type}::toString() const {

aten/src/ATen/templates/Type.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ Tensor Type::copy(const Tensor & src, bool non_blocking) const {
3838
}
3939

4040
Type & Type::toBackend(Backend b) const {
41-
return context->getType(b,scalarType());
41+
return at::globalContext().getType(b,scalarType());
4242
}
4343
Type & Type::toScalarType(ScalarType s) const {
44-
return context->getType(backend(),s);
44+
return at::globalContext().getType(backend(),s);
4545
}
4646
static std::vector<int64_t> defaultStrides(IntList sizes) {
4747
std::vector<int64_t> strides(sizes.size());

aten/src/ATen/templates/Type.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ enum class TypeID {
4545
};
4646

4747
struct AT_API Type {
48-
explicit Type(Context* context, TensorTypeId type_id, bool is_variable, bool is_undefined)
49-
: context(context), type_id_(type_id), is_variable_(is_variable), is_undefined_(is_undefined) {}
48+
explicit Type(TensorTypeId type_id, bool is_variable, bool is_undefined)
49+
: type_id_(type_id), is_variable_(is_variable), is_undefined_(is_undefined) {}
5050
virtual ~Type() {}
5151
virtual ScalarType scalarType() const = 0;
5252
virtual Backend backend() const = 0;
@@ -79,8 +79,6 @@ struct AT_API Type {
7979
Type & cuda() const {
8080
return this->toBackend(at::backendToCUDA(this->backend()));
8181
}
82-
Context& get_context() const { return *context; }
83-
8482
// contiguous IDs for all types in the system
8583
// for external dispatch
8684
virtual TypeID ID() const = 0;
@@ -111,7 +109,6 @@ struct AT_API Type {
111109
// virtual Tensor * add(Tensor & a, Tensor & b) = 0;
112110
${type_method_declarations}
113111
protected:
114-
Context* context;
115112
TensorTypeId type_id_;
116113
bool is_variable_;
117114
bool is_undefined_;

aten/src/ATen/templates/TypeDerived.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ static int getPointerDevice(void* ptr) {
3838
}
3939
#endif
4040

41-
${Type}::${Type}(Context* context)
42-
: Type(context, ${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
41+
${Type}::${Type}()
42+
: Type(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
4343
ScalarType ${Type}::scalarType() const {
4444
return ScalarType::${ScalarName};
4545
}
@@ -99,7 +99,7 @@ Storage ${Type}::unsafeStorageFromTH(void * th_pointer, bool retain) const {
9999
return Storage((${THStorage}*) th_pointer);
100100
}
101101
std::unique_ptr<Generator> ${Type}::generator() const {
102-
return std::unique_ptr<Generator>(new ${Generator}(context));
102+
return std::unique_ptr<Generator>(new ${Generator}(&at::globalContext()));
103103
}
104104

105105
const char * ${Type}::toString() const {

aten/src/ATen/templates/TypeDerived.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
namespace at {
1717

1818
struct ${Type} final : public Type {
19-
explicit ${Type}(Context* context);
19+
explicit ${Type}();
2020
virtual ScalarType scalarType() const override;
2121
virtual Backend backend() const override;
2222
virtual bool is_cuda() const override;

tools/autograd/templates/VariableType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using namespace torch::autograd::generated;
4343
namespace torch { namespace autograd {
4444

4545
VariableType::VariableType(Context* context, Type* baseType)
46-
: Type(context, baseType->type_id(), /*is_variable=*/true, /*is_undefined=*/false)
46+
: Type(baseType->type_id(), /*is_variable=*/true, /*is_undefined=*/false)
4747
, baseType(baseType)
4848
, id_(context->freshTypeID()) {
4949
str = std::string("Variable[") + baseType->toString() + "]";

0 commit comments

Comments
 (0)