@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/tensorexpr/kernel.h>
+#include <torch/csrc/jit/tensorexpr/ir_printer.h>
 #include <torch/csrc/jit/tensorexpr/schedule.h>
 
 using namespace torch::jit;
@@ -120,12 +121,67 @@ Expr TensorExprKernel::demoteOutput(const Expr& e, const torch::jit::Value* v) {
   return e;
 }
 
+static bool isOne(Expr e) {
+  auto const& n = e.AsNode<IntImm>();
+  if (!n) {
+    return false;
+  }
+  return n->value() == 1;
+}
+
+static std::vector<Expr> broadcastShapes(
+    const std::vector<Expr>& a,
+    const std::vector<Expr>& b) {
+  auto at = a.rbegin();
+  auto bt = b.rbegin();
+  std::vector<Expr> ret;
+  while (at != a.rend() || bt != b.rend()) {
+    if (at == a.rend()) {
+      ret.push_back(*bt++);
+      continue;
+    }
+    if (bt == b.rend()) {
+      ret.push_back(*at++);
+      continue;
+    }
+    // TODO: if neither *at nor *bt is 1, ensure they are identical
+    // expressions. Nb: `==` doesn't work since that simply produces a new
+    // Expr.
+    Expr dim = isOne(*at) ? *bt : *at;
+    ret.push_back(dim);
+    at++;
+    bt++;
+  }
+  std::reverse(ret.begin(), ret.end());
+  return ret;
+}
+
+template <typename... Args>
+static std::vector<Expr> broadcastShapes(
+    const std::vector<Expr>& a,
+    const std::vector<Expr>& b,
+    Args... args) {
+  return broadcastShapes(broadcastShapes(a, b), args...);
+}
+
+std::vector<Expr> TensorExprKernel::valueShape(const torch::jit::Value* v) {
+  auto it = tensors_.find(v->unique());
+  if (it == tensors_.end()) {
+    return {1};
+  }
+  return it->second.dims();
+}
+
 Tensor TensorExprKernel::ComputeOneOperand(
     const std::string& name,
     const torch::jit::Value* v,
     std::function<Expr(const Expr&)> inner_expr) {
+  auto const& n = v->node();
+  auto const& shape = valueShape(n->inputs()[0]);
   return Compute(
-      name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
+      name,
+      c10::fmap<DimArg>(shape),
+      [this, v, inner_expr](const std::vector<Var>& axes) {
         auto const& n = v->node();
         std::vector<Expr> inputs = {tensorOrConstant(n->inputs()[0], axes)};
@@ -139,8 +195,13 @@ Tensor TensorExprKernel::ComputeTwoOperand(
     const std::string& name,
     const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&)> inner_expr) {
+  auto const& n = v->node();
+  auto const& shape =
+      broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
   return Compute(
-      name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
+      name,
+      c10::fmap<DimArg>(shape),
+      [this, v, inner_expr](const std::vector<Var>& axes) {
         auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
@@ -157,8 +218,13 @@ Tensor TensorExprKernel::ComputeTwoOperandWithAlpha(
     const std::string& name,
     const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&)> inner_expr) {
+  auto const& n = v->node();
+  auto const& shape =
+      broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
   return Compute(
-      name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
+      name,
+      c10::fmap<DimArg>(shape),
+      [this, v, inner_expr](const std::vector<Var>& axes) {
         auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
@@ -176,8 +242,15 @@ Tensor TensorExprKernel::ComputeThreeOperand(
     const std::string& name,
     const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&, const Expr&)> inner_expr) {
+  auto const& n = v->node();
+  auto const& shape = broadcastShapes(
+      valueShape(n->inputs()[0]),
+      valueShape(n->inputs()[1]),
+      valueShape(n->inputs()[2]));
   return Compute(
-      name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
+      name,
+      c10::fmap<DimArg>(shape),
+      [this, v, inner_expr](const std::vector<Var>& axes) {
         auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
@@ -194,9 +267,18 @@ Tensor TensorExprKernel::ComputeThreeOperand(
 Tensor TensorExprKernel::ComputeFourOperand(
     const std::string& name,
     const torch::jit::Value* v,
-    std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)> inner_expr) {
+    std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)>
+        inner_expr) {
+  auto const& n = v->node();
+  auto const& shape = broadcastShapes(
+      valueShape(n->inputs()[0]),
+      valueShape(n->inputs()[1]),
+      valueShape(n->inputs()[2]),
+      valueShape(n->inputs()[3]));
   return Compute(
-      name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
+      name,
+      c10::fmap<DimArg>(shape),
+      [this, v, inner_expr](const std::vector<Var>& axes) {
         auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
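
For readers unfamiliar with the rule the new broadcastShapes helpers implement: operand shapes are aligned from their trailing dimensions, a dimension of size 1 is stretched to the other operand's size, and missing leading dimensions are filled in from the longer shape, i.e. the usual NumPy/PyTorch broadcasting convention, applied here to symbolic Expr dimensions. The following is a minimal standalone sketch of that same rule over plain integer sizes; broadcastSizes and the concrete shapes are illustrative only and not part of the commit.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative-only helper: the same right-aligned broadcasting walk as the
// commit's broadcastShapes, but over concrete integer sizes instead of
// symbolic Expr dimensions.
static std::vector<int64_t> broadcastSizes(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  auto at = a.rbegin();
  auto bt = b.rbegin();
  std::vector<int64_t> ret;
  while (at != a.rend() || bt != b.rend()) {
    if (at == a.rend()) {
      ret.push_back(*bt++);  // b has extra leading dims; copy them over
    } else if (bt == b.rend()) {
      ret.push_back(*at++);  // a has extra leading dims; copy them over
    } else {
      ret.push_back(*at == 1 ? *bt : *at);  // a size-1 dim stretches
      ++at;
      ++bt;
    }
  }
  std::reverse(ret.begin(), ret.end());  // shape was built back-to-front
  return ret;
}

int main() {
  // {8, 1, 6} broadcast against {7, 6} yields {8, 7, 6}.
  assert((broadcastSizes({8, 1, 6}, {7, 6}) == std::vector<int64_t>{8, 7, 6}));
  return 0;
}

The variadic broadcastShapes overload added in the commit simply folds this pairwise rule across however many operand shapes a node has, which is why ComputeThreeOperand and ComputeFourOperand can pass all of their input shapes in one call.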