Skip to content

Commit 310e7a1

Browse files
ZolotukhinMMikhail Zolotukhin
authored and
Mikhail Zolotukhin
committed
Backport some changes from master. (pytorch#193)
1 parent f3b2a0e commit 310e7a1

File tree

6 files changed

+41
-39
lines changed

6 files changed

+41
-39
lines changed

torch/csrc/jit/tensorexpr/codegen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
2929

3030
void RegisterCodeGenList::AddStmtFactoryMethod(
3131
const std::string& name,
32-
StmtFactoryMethod stmt_factory_method) {
32+
const StmtFactoryMethod& stmt_factory_method) {
3333
auto insert_ret =
3434
stmt_factory_methods_.insert(std::make_pair(name, stmt_factory_method));
3535
if (!insert_ret.second) {

torch/csrc/jit/tensorexpr/codegen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class RegisterCodeGenList {
138138
RegisterCodeGenList() {}
139139
TORCH_API void AddStmtFactoryMethod(
140140
const std::string& name,
141-
StmtFactoryMethod stmt_factory_method);
141+
const StmtFactoryMethod& stmt_factory_method);
142142
RegisterCodeGenList(const RegisterCodeGenList&) = delete;
143143
RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;
144144

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,7 @@ void TensorExprKernel::bindInput(const torch::jit::Value* input) {
10261026
}
10271027

10281028
TensorExprKernel::TensorExprKernel(const Graph& subgraph) {
1029-
KernelScope kernel_scope(kernel_arena_);
1029+
KernelScope kernel_scope(&kernel_arena_);
10301030

10311031
// Bind inputs to buffers.
10321032
n_inputs_ = subgraph.inputs().size();
@@ -1056,7 +1056,7 @@ TensorExprKernel::TensorExprKernel(const Graph& subgraph) {
10561056
}
10571057

10581058
void TensorExprKernel::run(Stack& stack) {
1059-
KernelScope kernel_scope(kernel_arena_);
1059+
KernelScope kernel_scope(&kernel_arena_);
10601060
// Set up arguments (inputs, then outputs) for kernel call.
10611061
auto inputs = last(stack, n_inputs_);
10621062
PickAndCheckBackendType(inputs);

torch/csrc/jit/tensorexpr/mem_arena.cpp

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,54 @@
1-
#include <stdexcept>
21
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
32

43
namespace torch {
54
namespace jit {
65
namespace tensorexpr {
76

7+
namespace {
8+
// Define in an anonymous namespace to hide this symbol from other compilation
9+
// units
10+
thread_local KernelArena* current_arena = nullptr;
11+
}
12+
813
KernelArena::~KernelArena() {
914
for (KernelScopedObject* p : kernel_objects_) {
1015
delete p;
1116
}
1217
}
1318

1419
KernelScopedObject::KernelScopedObject() {
15-
KernelArena& kernel = KernelArena::GetCurrentKernelArena();
16-
kernel.kernel_objects_.push_back(this);
20+
KernelArena* kernel = KernelArena::GetCurrentKernelArena();
21+
kernel->kernel_objects_.push_back(this);
1722
}
1823

1924
static std::vector<KernelArena*>& GetKernelArenaStack() {
2025
thread_local std::vector<KernelArena*> kernel_arena_stack;
2126
return kernel_arena_stack;
2227
}
2328

24-
KernelArena& KernelArena::GetCurrentKernelArena() {
25-
std::vector<KernelArena*>& kernel_arena_stack = GetKernelArenaStack();
26-
if (kernel_arena_stack.empty()) {
27-
throw std::runtime_error(
28-
"A KernelScope must be bound before creating KernelScopedObject");
29-
}
30-
return *kernel_arena_stack.back();
29+
void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) {
30+
current_arena = new_kernel_arena;
3131
}
3232

33-
KernelScope::KernelScope() : owning_kernel_arena_(true) {
34-
kernel_arena_ = new KernelArena;
35-
GetKernelArenaStack().push_back(kernel_arena_);
33+
KernelArena* KernelArena::GetCurrentKernelArena() {
34+
return current_arena;
3635
}
3736

38-
KernelScope::KernelScope(KernelArena& kernel_arena)
39-
: owning_kernel_arena_(false) {
40-
kernel_arena_ = &kernel_arena;
41-
GetKernelArenaStack().push_back(&kernel_arena);
37+
KernelScope::KernelScope() : owning_(true) {
38+
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
39+
KernelArena::SetCurrentKernelArena(new KernelArena);
4240
}
4341

44-
KernelScope::~KernelScope() noexcept(false) {
45-
std::vector<KernelArena*>& kernel_arena_stack = GetKernelArenaStack();
46-
if (kernel_arena_ != kernel_arena_stack.back()) {
47-
throw std::runtime_error("Mismatch KernelScope and kernel");
48-
}
49-
if (owning_kernel_arena_) {
50-
delete kernel_arena_;
42+
KernelScope::KernelScope(KernelArena* arena_) : owning_(false) {
43+
old_kernel_arena_ = KernelArena::GetCurrentKernelArena();
44+
KernelArena::SetCurrentKernelArena(arena_);
45+
}
46+
47+
KernelScope::~KernelScope() {
48+
if (owning_) {
49+
delete KernelArena::GetCurrentKernelArena();
5150
}
52-
kernel_arena_stack.pop_back();
51+
KernelArena::SetCurrentKernelArena(old_kernel_arena_);
5352
}
5453

5554
} // namespace tensorexpr

torch/csrc/jit/tensorexpr/mem_arena.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class KernelScopedObject;
1111
// An arena that manages all the underlying kernel-scoped objects.
1212
class KernelArena {
1313
public:
14-
static KernelArena& GetCurrentKernelArena();
14+
static KernelArena* GetCurrentKernelArena();
15+
static void SetCurrentKernelArena(KernelArena* new_arena);
1516
TORCH_API KernelArena() {}
1617
TORCH_API ~KernelArena();
1718

@@ -23,20 +24,23 @@ class KernelArena {
2324
};
2425

2526
// A RAII convenience wrapper on top of a kernel.
26-
// It either creates a Kernel, or take another existing Kernel, and sets it as
27-
// the current Kernel, as long as this KernelScope object is alive.
27+
// It either creates or takes an existing Kernel and sets it as the current
28+
// Kernel. When this object is destroyed, the previous Kernel is set as current,
29+
// and the created kernel is freed. If the kernel was passed, it stays alive.
2830
class KernelScope {
2931
public:
3032
TORCH_API KernelScope();
31-
TORCH_API explicit KernelScope(KernelArena& kernel_arena);
32-
TORCH_API ~KernelScope() noexcept(false);
33+
TORCH_API explicit KernelScope(KernelArena* arena_);
34+
TORCH_API ~KernelScope();
3335

3436
private:
3537
KernelScope(const KernelScope&) = delete;
3638
KernelScope& operator=(const KernelScope&) = delete;
37-
bool owning_kernel_arena_ = false;
38-
KernelArena* kernel_arena_ =
39-
nullptr; // possibly owned, if owning_kernel_arena_ == true
39+
KernelArena* kernel_arena_ = nullptr; // arena to be used in this scope
40+
KernelArena* old_kernel_arena_ =
41+
nullptr; // previous arena, will be restored in destructor
42+
bool owning_ = false; // determines whether the arena will be freed along with
43+
// the scope object
4044
};
4145

4246
// The base object managed by the Kernel.
@@ -55,4 +59,3 @@ class TORCH_API KernelScopedObject {
5559
} // namespace tensorexpr
5660
} // namespace jit
5761
} // namespace torch
58-

torch/csrc/jit/tensorexpr/unique_name_manager.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ const std::string& UniqueNameManager::get_unique_name(const Var* v) {
2323
name_hint = "v" + name_hint;
2424
}
2525
int& count = unique_name_count_[name_hint];
26-
while (1) {
26+
while (true) {
2727
// Even if with a new count, this name might already be used. For example
2828
// ("x", 1) could collidewith ("x_1", 0)
2929
int count_v = count++;

0 commit comments

Comments
 (0)