Skip to content

Commit ea1e2ad

Browse files
authored
Broadcast based on input shapes (pytorch#178)
1 parent 0387f04 commit ea1e2ad

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <torch/csrc/jit/tensorexpr/kernel.h>
2+
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
23
#include <torch/csrc/jit/tensorexpr/schedule.h>
34

45
using namespace torch::jit;
@@ -120,12 +121,67 @@ Expr TensorExprKernel::demoteOutput(const Expr& e, const torch::jit::Value* v) {
120121
return e;
121122
}
122123

124+
// Returns true only when `e` is an integer immediate (IntImm) whose value is
// exactly 1. Any expression that is not an IntImm node yields false, even if
// it would evaluate to 1.
static bool isOne(Expr e) {
  auto const& imm = e.AsNode<IntImm>();
  return imm && imm->value() == 1;
}
131+
132+
static std::vector<Expr> broadcastShapes(
133+
const std::vector<Expr>& a,
134+
const std::vector<Expr>& b) {
135+
auto at = a.rbegin();
136+
auto bt = b.rbegin();
137+
std::vector<Expr> ret;
138+
while (at != a.rend() || bt != b.rend()) {
139+
if (at == a.rend()) {
140+
ret.push_back(*bt++);
141+
continue;
142+
}
143+
if (bt == b.rend()) {
144+
ret.push_back(*at++);
145+
continue;
146+
}
147+
// TODO: if neither *at nor *bt is 1, ensure they are identical
148+
// expressions. Nb: `==` doesn't work since that simply produces a new
149+
// Expr.
150+
Expr dim = isOne(*at) ? *bt : *at;
151+
ret.push_back(dim);
152+
at++;
153+
bt++;
154+
}
155+
std::reverse(ret.begin(), ret.end());
156+
return ret;
157+
}
158+
159+
template <typename... Args>
160+
static std::vector<Expr> broadcastShapes(
161+
const std::vector<Expr>& a,
162+
const std::vector<Expr>& b,
163+
Args... args) {
164+
return broadcastShapes(broadcastShapes(a, b), args...);
165+
}
166+
167+
// Returns the shape to use for value `v` when computing broadcasts.
//
// If `tensors_` has an entry keyed by `v->unique()`, that entry's dims are
// returned. Otherwise the value is given the single-element shape {1}
// (presumably because it is a scalar/constant operand; confirm against
// callers).
std::vector<Expr> TensorExprKernel::valueShape(const torch::jit::Value* v) {
  auto found = tensors_.find(v->unique());
  if (found != tensors_.end()) {
    return found->second.dims();
  }
  return {1};
}
174+
123175
Tensor TensorExprKernel::ComputeOneOperand(
124176
const std::string& name,
125177
const torch::jit::Value* v,
126178
std::function<Expr(const Expr&)> inner_expr) {
179+
auto const& n = v->node();
180+
auto const& shape = valueShape(n->inputs()[0]);
127181
return Compute(
128-
name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
182+
name,
183+
c10::fmap<DimArg>(shape),
184+
[this, v, inner_expr](const std::vector<Var>& axes) {
129185
auto const& n = v->node();
130186
std::vector<Expr> inputs = {tensorOrConstant(n->inputs()[0], axes)};
131187

@@ -139,8 +195,13 @@ Tensor TensorExprKernel::ComputeTwoOperand(
139195
const std::string& name,
140196
const torch::jit::Value* v,
141197
std::function<Expr(const Expr&, const Expr&)> inner_expr) {
198+
auto const& n = v->node();
199+
auto const& shape =
200+
broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
142201
return Compute(
143-
name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
202+
name,
203+
c10::fmap<DimArg>(shape),
204+
[this, v, inner_expr](const std::vector<Var>& axes) {
144205
auto const& n = v->node();
145206
std::vector<Expr> inputs = {
146207
tensorOrConstant(n->inputs()[0], axes),
@@ -157,8 +218,13 @@ Tensor TensorExprKernel::ComputeTwoOperandWithAlpha(
157218
const std::string& name,
158219
const torch::jit::Value* v,
159220
std::function<Expr(const Expr&, const Expr&)> inner_expr) {
221+
auto const& n = v->node();
222+
auto const& shape =
223+
broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1]));
160224
return Compute(
161-
name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
225+
name,
226+
c10::fmap<DimArg>(shape),
227+
[this, v, inner_expr](const std::vector<Var>& axes) {
162228
auto const& n = v->node();
163229
std::vector<Expr> inputs = {
164230
tensorOrConstant(n->inputs()[0], axes),
@@ -176,8 +242,15 @@ Tensor TensorExprKernel::ComputeThreeOperand(
176242
const std::string& name,
177243
const torch::jit::Value* v,
178244
std::function<Expr(const Expr&, const Expr&, const Expr&)> inner_expr) {
245+
auto const& n = v->node();
246+
auto const& shape = broadcastShapes(
247+
valueShape(n->inputs()[0]),
248+
valueShape(n->inputs()[1]),
249+
valueShape(n->inputs()[2]));
179250
return Compute(
180-
name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
251+
name,
252+
c10::fmap<DimArg>(shape),
253+
[this, v, inner_expr](const std::vector<Var>& axes) {
181254
auto const& n = v->node();
182255
std::vector<Expr> inputs = {
183256
tensorOrConstant(n->inputs()[0], axes),
@@ -194,9 +267,18 @@ Tensor TensorExprKernel::ComputeThreeOperand(
194267
Tensor TensorExprKernel::ComputeFourOperand(
195268
const std::string& name,
196269
const torch::jit::Value* v,
197-
std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)> inner_expr) {
270+
std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)>
271+
inner_expr) {
272+
auto const& n = v->node();
273+
auto const& shape = broadcastShapes(
274+
valueShape(n->inputs()[0]),
275+
valueShape(n->inputs()[1]),
276+
valueShape(n->inputs()[2]),
277+
valueShape(n->inputs()[3]));
198278
return Compute(
199-
name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
279+
name,
280+
c10::fmap<DimArg>(shape),
281+
[this, v, inner_expr](const std::vector<Var>& axes) {
200282
auto const& n = v->node();
201283
std::vector<Expr> inputs = {
202284
tensorOrConstant(n->inputs()[0], axes),

torch/csrc/jit/tensorexpr/kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class TensorExprKernel {
8383
return t.call(indices);
8484
}
8585

86+
// Returns the dims of the tensor registered for `v`, or the shape {1} when no
// tensor is registered under `v->unique()`.
std::vector<Expr> valueShape(const torch::jit::Value* v);
87+
8688
void promoteInputs(std::vector<Expr>& inputs);
8789

8890
Expr demoteOutput(const Expr& e, const torch::jit::Value* v);

0 commit comments

Comments
 (0)