Skip to content

Commit 0387f04

Browse files
authored
Pass Graph to TensorExprKernel constructor (pytorch#177)
1 parent b7bfd90 commit 0387f04

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
288288
}
289289

290290
Operation createTensorExprOp(const Node* node) {
291-
auto kernel = std::make_shared<TensorExprKernel>(node);
291+
auto kernel = std::make_shared<TensorExprKernel>(*node->g(attr::Subgraph));
292292
return [kernel](Stack& stack) {
293293
RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
294294
kernel->run(stack);

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -831,22 +831,21 @@ void TensorExprKernel::bindInput(const torch::jit::Value* input) {
831831
}
832832
}
833833

834-
TensorExprKernel::TensorExprKernel(const Node* node) {
834+
TensorExprKernel::TensorExprKernel(const Graph& subgraph) {
835835
KernelScope kernel_scope(kernel_arena_);
836-
auto subgraph = node->g(attr::Subgraph);
837836

838837
// Bind inputs to buffers.
839-
n_inputs_ = subgraph->inputs().size();
840-
for (auto const& input : subgraph->inputs()) {
838+
n_inputs_ = subgraph.inputs().size();
839+
for (auto const& input : subgraph.inputs()) {
841840
bindInput(input);
842841
}
843842

844843
// Bind nodes to tensor compute expressions.
845-
for (auto const& n : subgraph->nodes()) {
844+
for (auto const& n : subgraph.nodes()) {
846845
if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) {
847846
continue;
848847
} else {
849-
for (torch::jit::Value* output : n->outputs()) {
848+
for (auto const& output : n->outputs()) {
850849
if (output->hasUses()) {
851850
tensors_.emplace(output->unique(), ComputeValue(output));
852851
}
@@ -855,7 +854,7 @@ TensorExprKernel::TensorExprKernel(const Node* node) {
855854
}
856855

857856
// Move output operands from `tensors_` to `tensor_outputs_`
858-
for (const auto& output : subgraph->outputs()) {
857+
for (const auto& output : subgraph.outputs()) {
859858
CHECK(tensors_.count(output->unique())) << "Output must be a tensor";
860859
tensor_outputs_.emplace_back(tensors_.at(output->unique()));
861860
tensors_.erase(output->unique());

torch/csrc/jit/tensorexpr/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ inline std::vector<Expr> computeIndicesToBroadcast(
4242

4343
class TensorExprKernel {
4444
public:
45-
explicit TensorExprKernel(const Node* node);
45+
explicit TensorExprKernel(const Graph& subgraph);
4646

4747
void run(Stack& stack);
4848

0 commit comments

Comments
 (0)