@@ -21,34 +21,34 @@ using namespace torch::jit::tensorexpr::schedule;
21
21
22
22
void testExprSimple01 () {
23
23
KernelScope kernel_scope;
24
- Tensor tensor =
24
+ Tensor* tensor =
25
25
Compute (" f" , {{16 , " X" }, {5 , " y" }}, [](const Var& x, const Var& y) {
26
26
return Expr (1 .0f ) + cast<float >(x) * x + cast<float >(y) * y;
27
27
});
28
- Var x = tensor. function ()->arg (0 );
29
- Var y = tensor. function ()->arg (1 );
28
+ Var x = tensor-> function ()->arg (0 );
29
+ Var y = tensor-> function ()->arg (1 );
30
30
Schedule sch = Schedule::make ({tensor});
31
31
Var x_outer;
32
32
Var x_inner;
33
33
Var x_tail;
34
- TensorOperation tail_op;
35
- tensor. SplitWithTail (x, 2 , true , &x_outer, &x_inner, &x_tail, &tail_op);
34
+ TensorOperation* tail_op;
35
+ tensor-> SplitWithTail (x, 2 , true , &x_outer, &x_inner, &x_tail, &tail_op);
36
36
37
37
Var x_2;
38
38
Var x_1;
39
39
Var x_tail_2;
40
- TensorOperation tail_op_2;
41
- tensor. SplitWithTail (x_outer, 2 , true , &x_2, &x_1, &x_tail_2, &tail_op_2);
40
+ TensorOperation* tail_op_2;
41
+ tensor-> SplitWithTail (x_outer, 2 , true , &x_2, &x_1, &x_tail_2, &tail_op_2);
42
42
}
43
43
44
44
void testExprLower01 () {
45
45
KernelScope kernel_scope;
46
- Tensor tensor =
46
+ Tensor* tensor =
47
47
Compute (" f" , {{16 , " x" }, {5 , " y" }}, [](const Var& x, const Var& y) {
48
48
return Expr (1 .0f ) + cast<float >(x) * x + cast<float >(y) * y;
49
49
});
50
- Var x = tensor. function ()->arg (0 );
51
- Var y = tensor. function ()->arg (1 );
50
+ Var x = tensor-> function ()->arg (0 );
51
+ Var y = tensor-> function ()->arg (1 );
52
52
Schedule sch = Schedule::make ({tensor});
53
53
Stmt stmt = sch.Lower ();
54
54
std::ostringstream oss;
@@ -62,15 +62,15 @@ void testExprSimple02() {
62
62
auto func = [](const Expr& x, const Expr& y) {
63
63
return Expr (1 .0f ) + cast<float >(x) * x + cast<float >(y) * y;
64
64
};
65
- Tensor tensor = Compute (" f" , {{26 , " x" }, {5 , " y" }}, func);
66
- Var x = tensor. function ()->arg (0 );
67
- Var y = tensor. function ()->arg (1 );
65
+ Tensor* tensor = Compute (" f" , {{26 , " x" }, {5 , " y" }}, func);
66
+ Var x = tensor-> function ()->arg (0 );
67
+ Var y = tensor-> function ()->arg (1 );
68
68
Schedule sch = Schedule::make ({tensor});
69
69
Var x_outer;
70
70
Var x_inner;
71
71
Var x_tail;
72
- TensorOperation tail_op;
73
- tensor. SplitWithTail (x, 4 , true , &x_outer, &x_inner, &x_tail, &tail_op);
72
+ TensorOperation* tail_op;
73
+ tensor-> SplitWithTail (x, 4 , true , &x_outer, &x_inner, &x_tail, &tail_op);
74
74
75
75
Stmt stmt = sch.Lower ();
76
76
std::ostringstream oss;
@@ -132,17 +132,17 @@ void testExprSplitWithMask01() {
132
132
const int N = 5 ;
133
133
Buffer a_buf (" a" , kFloat32 , {M, N});
134
134
Buffer b_buf (" b" , kFloat32 , {M, N});
135
- Tensor tensor =
135
+ Tensor* tensor =
136
136
Compute (" f" , {{M, " m" }, {N, " n" }}, [&](const Expr& m, const Expr& n) {
137
137
return a_buf (m, n) + b_buf (m, n) + 1 .0f ;
138
138
});
139
- Var m = tensor. function ()->arg (0 );
140
- Var n = tensor. function ()->arg (1 );
139
+ Var m = tensor-> function ()->arg (0 );
140
+ Var n = tensor-> function ()->arg (1 );
141
141
Var n_outer;
142
142
Var n_inner;
143
143
144
144
Schedule sch ({tensor});
145
- tensor. SplitWithMask (n, 4 , true , &n_outer, &n_inner);
145
+ tensor-> SplitWithMask (n, 4 , true , &n_outer, &n_inner);
146
146
147
147
Stmt stmt = sch.Lower ();
148
148
@@ -170,7 +170,7 @@ void testScheduleBroadcastAddBuffer() {
170
170
const int K = 6 ;
171
171
Buffer a_buf (" a" , kFloat32 , {M, N});
172
172
Buffer b_buf (" b" , kFloat32 , {N, K});
173
- Tensor c = Compute (
173
+ Tensor* c = Compute (
174
174
" broadcast_add" ,
175
175
{{M, " m" }, {N, " n" }, {K, " k" }},
176
176
[&](const Var& m, const Var& n, const Var& k) {
@@ -219,16 +219,16 @@ void testScheduleFunctionCall01() {
219
219
const int K = 6 ;
220
220
Buffer a_buf (" a" , kFloat32 , {M, N});
221
221
Buffer b_buf (" b" , kFloat32 , {N, K});
222
- Tensor c = Compute (
222
+ Tensor* c = Compute (
223
223
" broadcast_add" ,
224
224
{{M, " m" }, {N, " n" }, {K, " k" }},
225
225
[&](const Var& m, const Var& n, const Var& k) {
226
226
return a_buf (m, n) + b_buf (n, k);
227
227
});
228
- Tensor d = Compute (
228
+ Tensor* d = Compute (
229
229
" d" ,
230
230
{{M, " m" }, {N, " n" }, {K, " k" }},
231
- [&](const Var& m, const Var& n, const Var& k) { return c (m, n, k) + 1 ; });
231
+ [&](const Var& m, const Var& n, const Var& k) { return c-> call (m, n, k) + 1 ; });
232
232
233
233
Schedule sch ({d});
234
234
Stmt stmt = sch.Lower ();
@@ -283,31 +283,31 @@ void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
283
283
Buffer c_buf (" c" , kFloat32 , {M, N});
284
284
Buffer d_buf (" d" , kFloat32 , {M, K});
285
285
286
- Tensor x = Compute (
286
+ Tensor* x = Compute (
287
287
" x" ,
288
288
{{M, " m1" }, {N, " n1" }, {K, " k1" }},
289
289
[&](const Var& m, const Var& n, const Var& k) {
290
290
return a_buf (m, n) * b_buf (n, k);
291
291
});
292
- Tensor y = Compute (
292
+ Tensor* y = Compute (
293
293
" y" ,
294
294
{{M, " m2" }, {N, " n2" }, {K, " k2" }},
295
295
[&](const Var& m, const Var& n, const Var& k) {
296
- return c_buf (m, n) * d_buf (m, k) + x (m, n, k);
296
+ return c_buf (m, n) * d_buf (m, k) + x-> call (m, n, k);
297
297
});
298
- Tensor z = Compute (
298
+ Tensor* z = Compute (
299
299
" z" ,
300
300
{{M, " m3" }, {N, " n3" }, {K, " k3" }},
301
301
[&](const Var& m, const Var& n, const Var& k) {
302
- return x (m, n, k) + y (m, n, k);
302
+ return x-> call (m, n, k) + y-> call (m, n, k);
303
303
});
304
304
305
305
Schedule sch ({z});
306
306
for (const std::string& order : inline_order) {
307
307
if (order == " x" ) {
308
- x. ComputeInline ();
308
+ x-> ComputeInline ();
309
309
} else if (order == " y" ) {
310
- y. ComputeInline ();
310
+ y-> ComputeInline ();
311
311
} else {
312
312
throw std::runtime_error (" Invalid order: " + order);
313
313
}
@@ -361,7 +361,7 @@ void InlineFunc01Helper(const std::vector<std::string>& inline_order) {
361
361
}
362
362
363
363
if (inline_order.size () == 2 ) {
364
- Tensor z2 = Compute (
364
+ Tensor* z2 = Compute (
365
365
" z" ,
366
366
{{M, " m3" }, {N, " n3" }, {K, " k3" }},
367
367
[&](const Var& m, const Var& n, const Var& k) {
@@ -397,14 +397,14 @@ void testScheduleFuserStyle() {
397
397
Buffer a_buf (Var (" A" , kHandle ), kFloat32 , {Expr (kTotalSize )});
398
398
Var a = a_buf.data ();
399
399
400
- Tensor b =
400
+ Tensor* b =
401
401
Compute (" f" , {{kTotalSize , " i" }}, [&](const std::vector<Var>& axes) {
402
402
return a_buf (axes[0 ]) + 11 .0f ;
403
403
});
404
404
405
- Tensor c =
405
+ Tensor* c =
406
406
Compute (" g" , {{kTotalSize , " i" }}, [&](const std::vector<Var>& axes) {
407
- return b (axes[0 ]) + 1 .0f ;
407
+ return b-> call (axes[0 ]) + 1 .0f ;
408
408
});
409
409
410
410
Schedule sch ({b, c});
@@ -432,16 +432,16 @@ void testScheduleFuserThreeArg() {
432
432
Buffer c (Var (" C" , kHandle ), kFloat32 , {Expr (kTotalSize )});
433
433
Buffer d (Var (" D" , kHandle ), kFloat32 , {Expr (kTotalSize )});
434
434
435
- Tensor e = Compute (
435
+ Tensor* e = Compute (
436
436
" e" , {{kTotalSize , " i" }}, [&](const Var& i) { return a (i) + b (i); });
437
- Tensor f = Compute (
438
- " f" , {{kTotalSize , " i" }}, [&](const Var& i) { return e (i) + c (i); });
439
- Tensor g = Compute (
440
- " g" , {{kTotalSize , " i" }}, [&](const Var& i) { return f (i) + d (i); });
437
+ Tensor* f = Compute (
438
+ " f" , {{kTotalSize , " i" }}, [&](const Var& i) { return (*e) (i) + c (i); });
439
+ Tensor* g = Compute (
440
+ " g" , {{kTotalSize , " i" }}, [&](const Var& i) { return (*f) (i) + d (i); });
441
441
442
442
Schedule sch ({g});
443
- e. ComputeInline ();
444
- f. ComputeInline ();
443
+ e-> ComputeInline ();
444
+ f-> ComputeInline ();
445
445
Stmt s = sch.Lower ();
446
446
447
447
std::vector<float > a_data (kTotalSize , 1 .0f );
@@ -463,7 +463,7 @@ void testScheduleDynamicShape2D() {
463
463
Var n (" n" , kInt32 );
464
464
Buffer a (Var (" a" , kHandle ), kFloat32 , {m, n});
465
465
Buffer b (Var (" b" , kHandle ), kFloat32 , {m, n});
466
- Tensor c =
466
+ Tensor* c =
467
467
Compute (" c" , {{m, " m" }, {n, " n" }}, [&](const Var& i, const Var& j) {
468
468
return a (i, j) + b (i, j);
469
469
});
0 commit comments