@@ -7,12 +7,10 @@ namespace torch {
 namespace jit {
 namespace fuser {
 
-// Will return a new value of type val with the DataType dtype, if it's a
-// tensorview it will propagate the shape information from val.
-TORCH_CUDA_API Val* newValLike(const Val* val, DataType dtype) {
-  switch (val->getValType().value()) {
-    case (ValType::TensorView):
-      return val->as<TensorView>()->newForOutput(dtype);
+namespace {
+// Will return a new scalar value with ValType vtype and DataType dtype.
+Val* newScalar(ValType vtype, DataType dtype) {
+  switch (vtype) {
     case (ValType::NamedScalar):
     case (ValType::Scalar):
       switch (dtype) {
@@ -33,36 +31,92 @@ TORCH_CUDA_API Val* newValLike(const Val* val, DataType dtype) {
 
   TORCH_CHECK(
       false,
-      "Could not generate a new value of type ",
-      val->getValType().value(),
-      " with data type ",
-      val->getDataType().value());
+      "Was expecting a scalar type, but received ValType: ",
+      vtype,
+      " with DataType: ",
+      dtype);
+}
+
+TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
+  std::vector<TensorView*> tvs;
+  for (auto val : vals)
+    if (val->getValType() == ValType::TensorView)
+      tvs.push_back(static_cast<TensorView*>(val));
+
+  TORCH_CHECK(
+      !tvs.empty(),
+      "Tried to create new output TensorView but received empty list.");
+
+  std::vector<IterDomain*> out_domain(
+      tvs[0]->domain()->noReductions().size(), nullptr);
+
+  for (auto tv : tvs) {
+    auto dom = tv->domain()->noReductions();
+    TORCH_INTERNAL_ASSERT(
+        dom.size() == out_domain.size(),
+        "Invalid tensor view found while producing an output, it has ",
+        dom.size(),
+        " dimensions but expected ",
+        out_domain.size());
+    for (size_t i = 0; i < dom.size(); i++) {
+      if (out_domain[i] != nullptr)
+        continue;
+      if (dom[i]->isBroadcast())
+        continue;
+      out_domain[i] = new IterDomain(dom[i]->start(), dom[i]->extent());
+    }
+  }
+
+  std::transform(
+      out_domain.begin(),
+      out_domain.end(),
+      out_domain.begin(),
+      [](IterDomain* dom) {
+        if (dom == nullptr)
+          return new IterDomain(
+              new Int(0), new Int(1), ParallelType::Serial, false, false, true);
+        return dom;
+      });
+
+  return new TensorView(new TensorDomain(out_domain), dtype);
 }
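
Note (reviewer sketch, not part of the patch): how newOutputTV merges input domains when some inputs carry broadcast axes; a and b are hypothetical float TensorViews.

  // Suppose a has root domain [i0, B(1)] and b has [i0, i1], where B(1)
  // is a broadcast axis. Position 0 takes i0 from a; at position 1 a's
  // broadcast axis is skipped and i1 is taken from b, so the output
  // TensorView gets the domain [i0, i1].
  TensorView* out = newOutputTV({a, b}, DataType::Float);
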
 
-TORCH_CUDA_API Val* newValLike(const Val* val) {
-  return newValLike(val, val->getDataType().value());
+Val* newOutputVal(const std::vector<Val*>& vals) {
+  TORCH_INTERNAL_ASSERT(
+      !vals.empty(), "Cannot promote values if there aren't any.");
+
+  ValType out_vtype = vals[0]->getValType().value();
+  DataType out_dtype = vals[0]->getDataType().value();
+
+  for (auto val : vals) {
+    TORCH_CHECK(val->isVal(), "Invalid statement found during promotion.");
+    TORCH_CHECK(
+        val->getDataType().value() != DataType::Null,
+        "Invalid datatype found during promotion.");
+    out_vtype = promote_type(out_vtype, val->getValType().value());
+    out_dtype = promote_type(out_dtype, val->getDataType().value());
+  }
+
+  if (out_vtype == ValType::TensorView)
+    return newOutputTV(vals, out_dtype);
+
+  return newScalar(out_vtype, out_dtype);
 }
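
Note (sketch, assumed inputs): promotion across mixed ValTypes.

  Float* s = new Float(2.f);        // ValType::Scalar, DataType::Float
  Val* out = newOutputVal({s, tv}); // tv: a hypothetical float TensorView*
  // promote_type picks TensorView over Scalar, so out is built by
  // newOutputTV and is shaped like tv rather than being a scalar.
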
 
-TORCH_CUDA_API Val* promoteNew(Val* v1, Val* v2) {
-  // Can't promote two types if they aren't both
-  // values with valid data types.
-  TORCH_CHECK(v1->isVal() && v2->isVal());
+Val* newValLike(Val* val, DataType dtype) {
+  TORCH_CHECK(val->isVal(), "Invalid statement provided to create new value.");
   TORCH_CHECK(
-      v1->getDataType() != DataType::Null &&
-      v2->getDataType() != DataType::Null);
+      dtype != DataType::Null, "Invalid datatype provided for new value.");
 
-  ValType out_vtype =
-      promote_type(v1->getValType().value(), v2->getValType().value());
-  DataType out_dtype =
-      promote_type(v1->getDataType().value(), v2->getDataType().value());
+  ValType vtype = val->getValType().value();
 
-  if (out_vtype == v2->getValType().value())
-    return newValLike(v2, out_dtype);
+  if (vtype == ValType::TensorView)
+    return newOutputTV({val}, dtype);
 
-  return newValLike(v1, out_dtype);
+  return newScalar(vtype, dtype);
 }
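
Note: the in-tree use pattern for this helper (see binaryOp below) retypes a promoted output to Bool for logical ops:

  Val* out = newOutputVal({v1, v2});
  if (out->getDataType().value() != DataType::Bool)
    out = newValLike(out, DataType::Bool); // same ValType, Bool data type
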
 
-Val* newConstScalar(DataType dtype, int val) {
+Val* newConstScalar(DataType dtype, long int val) {
   switch (dtype) {
     case (DataType::Int):
       return new Int(val);
@@ -77,7 +131,7 @@ Val* newConstScalar(DataType dtype, int val) {
       val);
 }
 
-Val* newConstScalar(DataType dtype, float val) {
+Val* newConstScalar(DataType dtype, double val) {
   switch (dtype) {
     case (DataType::Float):
       return new Float(val);
@@ -92,6 +146,8 @@ Val* newConstScalar(DataType dtype, float val) {
       val);
 }
 
+} // namespace
+
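
Note (reviewer sketch): widening the overloads to long int/double keeps calls with literal arguments unambiguous; with the old int/float pair, a double argument matched both overloads equally and failed to compile.

  Val* i = newConstScalar(DataType::Int, 3L);    // long int overload
  Val* d = newConstScalar(DataType::Float, 0.5); // double overload; 0.5 was
                                                 // ambiguous between int/float
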
 TORCH_CUDA_API Val* castOp(DataType dtype, Val* v1) {
   if (v1->getDataType().value() == dtype)
     return v1;
@@ -118,7 +174,7 @@ TORCH_CUDA_API TensorView* castOp(DataType dtype, TensorView* v1) {
 // UNARY OPERATIONS
 
 TORCH_CUDA_API Val* unaryOp(UnaryOpType type, Val* v1) {
-  Val* out = newValLike(v1);
+  Val* out = newOutputVal({v1});
   new UnaryOp(type, out, v1);
   return out;
 }
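
Note: every unary helper in this file funnels through unaryOp; a sketch with an assumed input v:

  Val* out = unaryOp(UnaryOpType::Neg, v);
  // With a TensorView input, newOutputVal({v}) builds the output through
  // newOutputTV, so broadcast axes in v carry over to out.
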
@@ -177,7 +233,7 @@ TensorView* arithOpOverloads(
 } // namespace
 
 TORCH_CUDA_API Val* binaryOp(BinaryOpType type, Val* v1, Val* v2) {
-  Val* out = promoteNew(v1, v2);
+  Val* out = newOutputVal({v1, v2});
   if (is_logical_op(type)) {
     if (out->getDataType().value() != DataType::Bool)
       out = newValLike(out, DataType::Bool);
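
Note (sketch, assuming BinaryOpType::LT is the comparison used by the lt helper): logical ops get a Bool output regardless of the promoted input type.

  Val* pred = binaryOp(BinaryOpType::LT, a, b); // a, b: float values
  // newOutputVal({a, b}) promotes to Float first; is_logical_op(LT) then
  // re-creates the output as Bool via newValLike(out, DataType::Bool).
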
@@ -322,39 +378,73 @@ TORCH_CUDA_API TensorView* andOp(TensorView* v1, TensorView* v2) {
 
 // REDUCTION OPERATIONS
 
+namespace {
+// TODO: How do we adjust this so we can reduce to a single scalar value?
+TensorView* newForReduction(TensorView* tv, std::vector<unsigned int> axes) {
+  auto orig_domain = TensorDomain::noReductions(tv->getRootDomain());
+  std::set<unsigned int> axes_set(axes.begin(), axes.end());
+
+  std::vector<IterDomain*> new_domain;
+
+  TORCH_INTERNAL_ASSERT(
+      !axes_set.empty(),
+      "Asked for output of reduction, but no reduction axis provided.");
+  TORCH_INTERNAL_ASSERT(
+      (*(axes_set.rbegin())) < orig_domain.size(),
+      "Error setting up reduction, reduction axis is outside nDims. Keep in mind reductions are relative to root domains, not modified views.");
+
+  for (decltype(orig_domain.size()) dim = 0; dim < orig_domain.size(); dim++) {
+    IterDomain* id = orig_domain[dim];
+
+    bool isReduction = false;
+    if ((*axes_set.begin()) == dim) {
+      isReduction = true;
+      axes_set.erase(axes_set.begin());
+    }
+
+    new_domain.push_back(new IterDomain(
+        id->start(), id->extent(), ParallelType::Serial, isReduction));
+  }
+
+  TensorDomain* td = new TensorDomain(new_domain);
+  return new TensorView(td, tv->getDataType().value());
+}
+
+} // namespace
+
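
Note (sketch): newForReduction works on the root domain; for a tv with root domain [i0, i1, i2]:

  TensorView* red = newForReduction(tv, {1});
  // red's domain is [i0, r1, i2], where r1 is a new IterDomain over i1's
  // start/extent constructed with isReduction == true.
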
 TensorView* reductionOp(
     BinaryOpType reduction_op_type,
     const std::vector<int>& axes,
     Val* init,
-    TensorView* v1) {
+    TensorView* tv) {
   TORCH_CHECK(
       init->isConstScalar(),
       "Cannot create a reduction operation where the initial value is not a const scalar.");
 
   TORCH_CHECK(
-      v1->getRootDomain() == v1->domain(),
-      "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/reorder/computeAt.");
+      TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()),
+      "Reducing a tensor once it has undergone transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
 
   std::vector<unsigned int> uint_axes;
   for (int axis : axes) {
     if (axis < 0)
-      axis += int(v1->nDims());
+      axis += int(tv->nDims());
 
     TORCH_CHECK(
-        axis >= 0 && (unsigned int)axis < v1->nDims(),
+        axis >= 0 && (unsigned int)axis < tv->nDims(),
         "Reduction on invalid axis, received: ",
         axis,
         " however tensor view only has ",
-        v1->nDims(),
+        tv->nDims(),
         " dims.");
 
     uint_axes.push_back((unsigned int)axis);
   }
 
-  TensorView* out = v1->newForReduction(uint_axes);
-  if (init->getDataType().value() != v1->getDataType().value())
-    init = castOp(v1->getDataType().value(), init);
-  new ReductionOp(reduction_op_type, init, out, v1);
+  TensorView* out = newForReduction(tv, uint_axes);
+  if (init->getDataType().value() != tv->getDataType().value())
+    init = castOp(tv->getDataType().value(), init);
+  new ReductionOp(reduction_op_type, init, out, tv);
   return out;
 }
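
Note: sum (below) is the canonical caller; for a hypothetical float tv, summing over axis 1 looks like this, with the init value chosen by sum's own dtype switch:

  TensorView* out = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv);
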
 
@@ -377,6 +467,48 @@ TORCH_CUDA_API TensorView* sum(TensorView* v1, const std::vector<int>& axes) {
   return reductionOp(BinaryOpType::Add, axes, init, v1);
 }
 
+TORCH_CUDA_API TensorView* broadcast(
+    TensorView* inp,
+    const std::vector<bool>& is_broadcast_dim) {
+  auto nBCastDims = is_broadcast_dim.size();
+  // Validate is_broadcast_dim
+  unsigned int n_broadcasts = 0;
+  for (auto ent : is_broadcast_dim)
+    if (ent)
+      n_broadcasts++;
+  TORCH_CHECK(
+      nBCastDims - n_broadcasts == inp->nDims(),
+      "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
+      inp->nDims(),
+      " but received ",
+      nBCastDims - n_broadcasts);
+
+  if (n_broadcasts == 0) {
+    auto identity = unaryOp(UnaryOpType::Set, inp);
+    TORCH_INTERNAL_ASSERT(
+        identity->getValType().value() == ValType::TensorView,
+        "Expected identity op, but didn't get a TensorView back.");
+    return static_cast<TensorView*>(identity);
+  }
+
+  std::vector<IterDomain*> out_domain;
+  size_t iinp = 0, ibdim = 0;
+  while (ibdim < is_broadcast_dim.size()) {
+    if (is_broadcast_dim[ibdim]) {
+      out_domain.push_back(new IterDomain(
+          new Int(0), new Int(1), ParallelType::Serial, false, false, true));
+    } else {
+      out_domain.push_back(inp->axis(iinp));
+      iinp++;
+    }
+    ibdim++;
+  }
+  TensorView* out_tensor =
+      new TensorView(new TensorDomain(out_domain), inp->getDataType().value());
+  new BroadcastOp(out_tensor, inp);
+  return out_tensor;
+}
+
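
Note (sketch): adding a leading broadcast axis to a 2-D TensorView; the false entries must line up with inp's existing dims.

  // tv2d has root domain [i0, i1]; the mask has nDims() == 2 false entries.
  TensorView* out = broadcast(tv2d, {true, false, false});
  // out's domain is [B(1), i0, i1]; an all-false mask just emits a Set op.
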
 // COMPOUND OPERATIONS
 
 // add_alpha
@@ -504,7 +636,7 @@ TORCH_CUDA_API Val* where(Val* c, Val* v1, Val* v2) {
       "Condition should be of DataType Bool, not ",
       c->getDataType().value());
 
-  Val* out = promoteNew(v1, v2);
+  Val* out = newOutputVal({v1, v2});
   new TernaryOp(TernaryOpType::Where, out, c, v1, v2);
   return out;
 }
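
Note (sketch, assumed inputs): where requires a Bool condition and promotes the two branches.

  Val* pred = binaryOp(BinaryOpType::LT, a, b); // Bool via the logical-op path
  Val* out = where(pred, v1, v2);               // out from newOutputVal({v1, v2})
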
@@ -533,6 +665,8 @@ TORCH_CUDA_API TensorView* where(
   return arithOpOverloads(where, v1, v2, v3);
 }
 
+// TERNARY OPERATIONS
+
 TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
   TORCH_CHECK(
       in->getDataType().value() == thresh->getDataType().value() &&
@@ -544,7 +678,7 @@ TORCH_CUDA_API Val* threshold(Val* in, Val* thresh, Val* value) {
       value->getValType().value() == ValType::Scalar,
       "Thresh and Value values should be Scalars");
 
-  Val* out = newValLike(in);
+  Val* out = newOutputVal({in});
 
   new TernaryOp(TernaryOpType::Threshold, out, in, thresh, value);
   return out;
@@ -565,7 +699,7 @@ TORCH_CUDA_API Val* clamp(Val* in, Val* min_val, Val* max_val) {
       max_val->getValType().value() == ValType::Scalar,
       "Min and Max values should be Scalars");
 
-  Val* out = newValLike(in);
+  Val* out = newOutputVal({in});
 
   new TernaryOp(TernaryOpType::Clamp, out, in, min_val, max_val);
   return out;