Commit 58d5cc0

ZolotukhinM (Mikhail Zolotukhin) authored and committed
Remove TensorNode and TensorOperationNode classes and remove some wrapper accessors to make the code more explicit. (pytorch#186)
* Remove wrapper function accessors from TensorNode: instead access function_'s members directly through function().
* Remove TensorNode class.
* Remove TensorOperationNode class.
1 parent debbd4b commit 58d5cc0
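For context, the shape of the change at a typical call site: Compute now returns a heap-allocated Tensor*, so member access switches from "." to "->", and the removed wrapper accessors are reached through function() instead. A minimal sketch mirroring testLLVMSimpleMath01 from the diff below (the wrapping function name "example" is illustrative, not from the patch; the old lines are shown as comments):

    void example() {
      KernelScope kernel_scope;
      const int N = 1024;
      // Before: Tensor tensor = Compute(...);
      //         Buffer f_buf(tensor.function()->func_var(), kFloat32, {N});
      Tensor* tensor = Compute(
          "f", {{N, "i"}}, [](const Var& i) { return cast<float>(i * i + 1); });
      Schedule sch = Schedule::make({tensor});
      Stmt stmt = sch.Lower();
      Buffer f_buf(tensor->function()->func_var(), kFloat32, {N});
    }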

File tree

10 files changed: +182, -313 lines

test/cpp/tensorexpr/test_llvm.cpp

Lines changed: 8 additions & 8 deletions
@@ -797,11 +797,11 @@ void testLLVMStoreFloat() {
 void testLLVMSimpleMath01() {
   KernelScope kernel_scope;
   const int N = 1024;
-  Tensor tensor = Compute(
+  Tensor* tensor = Compute(
       "f", {{N, "i"}}, [](const Var& i) { return cast<float>(i * i + 1); });
   Schedule sch = Schedule::make({tensor});
   Stmt stmt = sch.Lower();
-  Buffer f_buf(tensor.function()->func_var(), kFloat32, {N});
+  Buffer f_buf(tensor->function()->func_var(), kFloat32, {N});
   LLVMCodeGen cg(stmt, {f_buf});
 
   PaddedBuffer<float> f_v(N, "f_v");
@@ -820,11 +820,11 @@ void testLLVMComputeMul() {
   const int N = 1024;
   Buffer a(Var("a", kHandle), kFloat32, {N});
   Buffer b(Var("b", kHandle), kFloat32, {N});
-  Tensor c = Compute("c", {{N, "i"}}, [&](const Var& i) {
+  Tensor* c = Compute("c", {{N, "i"}}, [&](const Var& i) {
     return Load::make(a, i, 1) * Load::make(b, i, 1);
   });
 
-  Buffer c_buf(c.function()->func_var(), kFloat32, {N});
+  Buffer c_buf(c->function()->func_var(), kFloat32, {N});
   Schedule sch = Schedule::make({c});
   Stmt s = sch.Lower();
 
@@ -844,13 +844,13 @@ void testLLVMBroadcastAdd() {
   const int N = 1024;
   Buffer a(Var("a", kHandle), kFloat32, {M, N});
   Buffer b(Var("b", kHandle), kFloat32, {N});
-  Tensor c =
+  Tensor* c =
       Compute("c", {{M, "i"}, {N, "j"}}, [&](const Var& i, const Var& j) {
         Expr mask(1);
         return Load::make(a, i * N + j, mask) + Load::make(b, j, mask);
       });
 
-  Buffer c_buf(c.function()->func_var(), kFloat32, {M, N});
+  Buffer c_buf(c->function()->func_var(), kFloat32, {M, N});
   Schedule sch = Schedule::make({c});
   Stmt s = sch.Lower();
 
@@ -920,7 +920,7 @@ void testLLVMTensorDynamicShapeAdd() {
   Var n("n", kInt32);
   Buffer a(Var("a", kHandle), kFloat32, {n});
   Buffer b(Var("b", kHandle), kFloat32, {n});
-  Tensor c =
+  Tensor* c =
       Compute("c", {{n, "n"}}, [&](const Var& i) { return a(i) + b(i); });
   Schedule sch = Schedule::make({c});
   Stmt s = sch.Lower();
@@ -943,7 +943,7 @@ void testLLVMDynamicShape2D() {
   Var n("n", kInt32);
   Buffer a(Var("a", kHandle), kFloat32, {m, n});
   Buffer b(Var("b", kHandle), kFloat32, {m, n});
-  Tensor c =
+  Tensor* c =
       Compute("c", {{m, "m"}, {n, "n"}}, [&](const Var& i, const Var& j) {
         return a(i, j) + b(i, j);
       });

test/cpp/tensorexpr/test_schedule.cpp

Lines changed: 42 additions & 42 deletions
@@ -21,34 +21,34 @@ using namespace torch::jit::tensorexpr::schedule;
 
 void testExprSimple01() {
   KernelScope kernel_scope;
-  Tensor tensor =
+  Tensor* tensor =
       Compute("f", {{16, "X"}, {5, "y"}}, [](const Var& x, const Var& y) {
         return Expr(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
       });
-  Var x = tensor.function()->arg(0);
-  Var y = tensor.function()->arg(1);
+  Var x = tensor->function()->arg(0);
+  Var y = tensor->function()->arg(1);
   Schedule sch = Schedule::make({tensor});
   Var x_outer;
   Var x_inner;
   Var x_tail;
-  TensorOperation tail_op;
-  tensor.SplitWithTail(x, 2, true, &x_outer, &x_inner, &x_tail, &tail_op);
+  TensorOperation* tail_op;
+  tensor->SplitWithTail(x, 2, true, &x_outer, &x_inner, &x_tail, &tail_op);
 
   Var x_2;
   Var x_1;
   Var x_tail_2;
-  TensorOperation tail_op_2;
-  tensor.SplitWithTail(x_outer, 2, true, &x_2, &x_1, &x_tail_2, &tail_op_2);
+  TensorOperation* tail_op_2;
+  tensor->SplitWithTail(x_outer, 2, true, &x_2, &x_1, &x_tail_2, &tail_op_2);
 }
 
 void testExprLower01() {
   KernelScope kernel_scope;
-  Tensor tensor =
+  Tensor* tensor =
       Compute("f", {{16, "x"}, {5, "y"}}, [](const Var& x, const Var& y) {
         return Expr(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
       });
-  Var x = tensor.function()->arg(0);
-  Var y = tensor.function()->arg(1);
+  Var x = tensor->function()->arg(0);
+  Var y = tensor->function()->arg(1);
   Schedule sch = Schedule::make({tensor});
   Stmt stmt = sch.Lower();
   std::ostringstream oss;
@@ -62,15 +62,15 @@ void testExprSimple02() {
   auto func = [](const Expr& x, const Expr& y) {
     return Expr(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
   };
-  Tensor tensor = Compute("f", {{26, "x"}, {5, "y"}}, func);
-  Var x = tensor.function()->arg(0);
-  Var y = tensor.function()->arg(1);
+  Tensor* tensor = Compute("f", {{26, "x"}, {5, "y"}}, func);
+  Var x = tensor->function()->arg(0);
+  Var y = tensor->function()->arg(1);
   Schedule sch = Schedule::make({tensor});
   Var x_outer;
   Var x_inner;
   Var x_tail;
-  TensorOperation tail_op;
-  tensor.SplitWithTail(x, 4, true, &x_outer, &x_inner, &x_tail, &tail_op);
+  TensorOperation* tail_op;
+  tensor->SplitWithTail(x, 4, true, &x_outer, &x_inner, &x_tail, &tail_op);
 
   Stmt stmt = sch.Lower();
   std::ostringstream oss;
@@ -132,17 +132,17 @@ void testExprSplitWithMask01() {
   const int N = 5;
   Buffer a_buf("a", kFloat32, {M, N});
   Buffer b_buf("b", kFloat32, {M, N});
-  Tensor tensor =
+  Tensor* tensor =
       Compute("f", {{M, "m"}, {N, "n"}}, [&](const Expr& m, const Expr& n) {
         return a_buf(m, n) + b_buf(m, n) + 1.0f;
       });
-  Var m = tensor.function()->arg(0);
-  Var n = tensor.function()->arg(1);
+  Var m = tensor->function()->arg(0);
+  Var n = tensor->function()->arg(1);
   Var n_outer;
   Var n_inner;
 
   Schedule sch({tensor});
-  tensor.SplitWithMask(n, 4, true, &n_outer, &n_inner);
+  tensor->SplitWithMask(n, 4, true, &n_outer, &n_inner);
 
   Stmt stmt = sch.Lower();
 
@@ -170,7 +170,7 @@ void testScheduleBroadcastAddBuffer() {
   const int K = 6;
   Buffer a_buf("a", kFloat32, {M, N});
   Buffer b_buf("b", kFloat32, {N, K});
-  Tensor c = Compute(
+  Tensor* c = Compute(
       "broadcast_add",
       {{M, "m"}, {N, "n"}, {K, "k"}},
       [&](const Var& m, const Var& n, const Var& k) {
@@ -219,16 +219,16 @@ void testScheduleFunctionCall01() {
   const int K = 6;
   Buffer a_buf("a", kFloat32, {M, N});
   Buffer b_buf("b", kFloat32, {N, K});
-  Tensor c = Compute(
+  Tensor* c = Compute(
       "broadcast_add",
       {{M, "m"}, {N, "n"}, {K, "k"}},
       [&](const Var& m, const Var& n, const Var& k) {
         return a_buf(m, n) + b_buf(n, k);
       });
-  Tensor d = Compute(
+  Tensor* d = Compute(
       "d",
       {{M, "m"}, {N, "n"}, {K, "k"}},
-      [&](const Var& m, const Var& n, const Var& k) { return c(m, n, k) + 1; });
+      [&](const Var& m, const Var& n, const Var& k) { return c->call(m, n, k) + 1; });
 
   Schedule sch({d});
   Stmt stmt = sch.Lower();
@@ -283,31 +283,31 @@ void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
   Buffer c_buf("c", kFloat32, {M, N});
   Buffer d_buf("d", kFloat32, {M, K});
 
-  Tensor x = Compute(
+  Tensor* x = Compute(
       "x",
       {{M, "m1"}, {N, "n1"}, {K, "k1"}},
       [&](const Var& m, const Var& n, const Var& k) {
         return a_buf(m, n) * b_buf(n, k);
       });
-  Tensor y = Compute(
+  Tensor* y = Compute(
       "y",
       {{M, "m2"}, {N, "n2"}, {K, "k2"}},
       [&](const Var& m, const Var& n, const Var& k) {
-        return c_buf(m, n) * d_buf(m, k) + x(m, n, k);
+        return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k);
       });
-  Tensor z = Compute(
+  Tensor* z = Compute(
       "z",
       {{M, "m3"}, {N, "n3"}, {K, "k3"}},
       [&](const Var& m, const Var& n, const Var& k) {
-        return x(m, n, k) + y(m, n, k);
+        return x->call(m, n, k) + y->call(m, n, k);
       });
 
   Schedule sch({z});
   for (const std::string& order : inline_order) {
     if (order == "x") {
-      x.ComputeInline();
+      x->ComputeInline();
     } else if (order == "y") {
-      y.ComputeInline();
+      y->ComputeInline();
     } else {
       throw std::runtime_error("Invalid order: " + order);
     }
@@ -361,7 +361,7 @@ void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
   }
 
   if (inline_order.size() == 2) {
-    Tensor z2 = Compute(
+    Tensor* z2 = Compute(
         "z",
         {{M, "m3"}, {N, "n3"}, {K, "k3"}},
         [&](const Var& m, const Var& n, const Var& k) {
@@ -397,14 +397,14 @@ void testScheduleFuserStyle() {
   Buffer a_buf(Var("A", kHandle), kFloat32, {Expr(kTotalSize)});
   Var a = a_buf.data();
 
-  Tensor b =
+  Tensor* b =
       Compute("f", {{kTotalSize, "i"}}, [&](const std::vector<Var>& axes) {
         return a_buf(axes[0]) + 11.0f;
       });
 
-  Tensor c =
+  Tensor* c =
       Compute("g", {{kTotalSize, "i"}}, [&](const std::vector<Var>& axes) {
-        return b(axes[0]) + 1.0f;
+        return b->call(axes[0]) + 1.0f;
       });
 
   Schedule sch({b, c});
@@ -432,16 +432,16 @@ void testScheduleFuserThreeArg() {
   Buffer c(Var("C", kHandle), kFloat32, {Expr(kTotalSize)});
   Buffer d(Var("D", kHandle), kFloat32, {Expr(kTotalSize)});
 
-  Tensor e = Compute(
+  Tensor* e = Compute(
       "e", {{kTotalSize, "i"}}, [&](const Var& i) { return a(i) + b(i); });
-  Tensor f = Compute(
-      "f", {{kTotalSize, "i"}}, [&](const Var& i) { return e(i) + c(i); });
-  Tensor g = Compute(
-      "g", {{kTotalSize, "i"}}, [&](const Var& i) { return f(i) + d(i); });
+  Tensor* f = Compute(
+      "f", {{kTotalSize, "i"}}, [&](const Var& i) { return (*e)(i) + c(i); });
+  Tensor* g = Compute(
+      "g", {{kTotalSize, "i"}}, [&](const Var& i) { return (*f)(i) + d(i); });
 
   Schedule sch({g});
-  e.ComputeInline();
-  f.ComputeInline();
+  e->ComputeInline();
+  f->ComputeInline();
   Stmt s = sch.Lower();
 
   std::vector<float> a_data(kTotalSize, 1.0f);
@@ -463,7 +463,7 @@ void testScheduleDynamicShape2D() {
   Var n("n", kInt32);
   Buffer a(Var("a", kHandle), kFloat32, {m, n});
   Buffer b(Var("b", kHandle), kFloat32, {m, n});
-  Tensor c =
+  Tensor* c =
       Compute("c", {{m, "m"}, {n, "n"}}, [&](const Var& i, const Var& j) {
         return a(i, j) + b(i, j);
       });
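Note the pattern these test updates establish for reading one tensor inside another's body: with c now a Tensor*, the old call syntax c(m, n, k) no longer applies, and the patch uses either c->call(...) or an explicit dereference (*c)(...). A short sketch adapted from testScheduleFunctionCall01 above (M, N, K, a_buf, and b_buf as defined in that test):

    Tensor* c = Compute(
        "broadcast_add",
        {{M, "m"}, {N, "n"}, {K, "k"}},
        [&](const Var& m, const Var& n, const Var& k) {
          return a_buf(m, n) + b_buf(n, k);
        });
    // Reads of c inside another Compute go through call() or (*c)(...):
    Tensor* d = Compute(
        "d",
        {{M, "m"}, {N, "n"}, {K, "k"}},
        [&](const Var& m, const Var& n, const Var& k) {
          return c->call(m, n, k) + 1;  // equivalently: (*c)(m, n, k) + 1
        });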

torch/csrc/jit/tensorexpr/codegen.h

Lines changed: 3 additions & 3 deletions
@@ -50,9 +50,9 @@ class CodeGen::BufferArg {
  public:
   BufferArg(const Buffer& buffer)
       : var_(buffer.data()), dtype_(buffer.dtype()) {}
-  BufferArg(const Tensor& tensor)
-      : var_(tensor.function()->func_var()),
-        dtype_(tensor.function()->body().dtype()) {}
+  BufferArg(Tensor* tensor)
+      : var_(tensor->function()->func_var()),
+        dtype_(tensor->function()->body().dtype()) {}
   BufferArg(const Function& func)
       : var_(func.func_var()), dtype_(func.body().dtype()) {}
   BufferArg(const Var& var) : var_(var), dtype_(var.dtype()), isVar_(true) {}
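Because the constructor stays implicit, a Tensor* should convert to a BufferArg at a codegen boundary just as const Tensor& did before. A sketch of what that permits, assuming the codegen constructors accept BufferArg values (the nested-class placement suggests so, but this hunk does not show them):

    KernelScope kernel_scope;
    const int N = 1024;
    Tensor* tensor = Compute(
        "f", {{N, "i"}}, [](const Var& i) { return cast<float>(i); });
    Schedule sch = Schedule::make({tensor});
    Stmt stmt = sch.Lower();
    LLVMCodeGen cg(stmt, {tensor});  // Tensor* converts implicitly to BufferArg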

torch/csrc/jit/tensorexpr/function.cpp

Lines changed: 10 additions & 10 deletions
@@ -23,7 +23,7 @@ static void unpack_dim_args(
 
 } // namespace
 
-Tensor Compute(
+Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
     std::function<Expr(const std::vector<Var>&)> body_func) {
@@ -33,10 +33,10 @@ Tensor Compute(
   Expr body = body_func(args);
   Function* func = new Function(
       func_name, std::move(dims), std::move(args), std::move(body));
-  return Tensor(func, 0);
+  return new Tensor(func, 0);
 }
 
-Tensor Compute(
+Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
     std::function<Expr(const Var&)> body_func) {
@@ -47,10 +47,10 @@ Tensor Compute(
   Expr body = body_func(args[0]);
   Function* func =
       new Function(func_name, std::move(dims), std::move(args), std::move(body));
-  return Tensor(func, 0);
+  return new Tensor(func, 0);
 }
 
-Tensor Compute(
+Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
     std::function<Expr(const Var&, const Var&)> body_func) {
@@ -61,10 +61,10 @@ Tensor Compute(
   Expr body = body_func(args[0], args[1]);
   Function* func = new Function(
       func_name, std::move(dims), std::move(args), std::move(body));
-  return Tensor(func, 0);
+  return new Tensor(func, 0);
 }
 
-Tensor Compute(
+Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
     std::function<Expr(const Var&, const Var&, const Var&)> body_func) {
@@ -75,10 +75,10 @@ Tensor Compute(
   Expr body = body_func(args[0], args[1], args[2]);
   Function* func = new Function(
       func_name, std::move(dims), std::move(args), std::move(body));
-  return Tensor(func, 0);
+  return new Tensor(func, 0);
 }
 
-Tensor Compute(
+Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
     std::function<Expr(const Var&, const Var&, const Var&, const Var&)>
@@ -90,7 +90,7 @@ Tensor Compute(
   Expr body = body_func(args[0], args[1], args[2], args[3]);
   Function* func = new Function(
       func_name, std::move(dims), std::move(args), std::move(body));
-  return Tensor(func, 0);
+  return new Tensor(func, 0);
 }
 
 Stmt Function::ElementStmt() {
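All five Compute overloads now end in return new Tensor(func, 0);, and no call site in the patch ever deletes the result. Every test opens a KernelScope before calling Compute, which suggests the allocations are reclaimed when that scope ends rather than by the caller (an assumption; the diff does not show Tensor's base class or destructor). The resulting usage pattern:

    {
      KernelScope kernel_scope;  // tests open a scope before any Compute call
      Tensor* t = Compute(
          "f", {{16, "x"}, {5, "y"}}, [](const Var& x, const Var& y) {
            return Expr(1.0f) + cast<float>(x) * x + cast<float>(y) * y;
          });
      // No delete here: callers hold the raw pointer for the scope's lifetime.
    }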
