Skip to content

Merge from upstream #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ static void sigmoid_kernel(Tensor& result, const Tensor& self) {

#define IMPLEMENT_FLOAT_KERNEL(dispatchtypes, op) \
static void op##_kernel(Tensor& result, const Tensor& self) { \
checkBackend(#op, {result}, kCPU); \
AT_DISPATCH_##dispatchtypes##_TYPES(self.type(), #op, [&] { \
if (self.is_contiguous() && result.is_contiguous()) { \
vml::v##op( \
Expand Down
18 changes: 18 additions & 0 deletions caffe2/core/nomnigraph/Representations/NeuralNet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,24 @@ void coalesceInsertedDataDependencies(repr::NNModule* m) {
}
}

bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef) {
auto nodeOutputs = nn::getOutputs(nodeRef);
NOM_REQUIRE_OR_RET_FALSE(nodeOutputs.size() == 1);
auto nodeConsumers = nn::getConsumers(nodeOutputs.front());
return nodeConsumers.size() == 1;
}

NNNodeMatchCriteria matchAnyNode() {
return [](NNGraph::NodeRef /* unused */) { return true; };
}

NNSubtree operatorTree(
const NNNodeMatchCriteria& root,
const std::vector<NNSubtree>& childrenCriteria,
int count) {
return NNSubtree(matchAnyNode(), {NNSubtree(root, childrenCriteria)}, count);
}

} // namespace nn

} // namespace repr
Expand Down
4 changes: 4 additions & 0 deletions caffe2/core/nomnigraph/include/nomnigraph/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,10 @@ class Graph {
return result;
}

const size_t getNodesCount() const {
return (size_t)nodes_.size();
}

const std::vector<EdgeRef> getMutableEdges() {
std::vector<EdgeRef> result;
for (auto& e : edges_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "nomnigraph/Representations/ControlFlow.h"
#include "nomnigraph/Support/Casting.h"
#include "nomnigraph/Support/Pointer.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"

#include <string>
#include <type_traits>
Expand Down Expand Up @@ -420,6 +421,58 @@ void coalesceInsertedDataDependencies(repr::NNModule* m);
template <NNGraph* G>
struct NodeHelper {};

using NNNodeMatchCriteria = std::function<bool(NNGraph::NodeRef)>;
using NNSubtree = nom::matcher::SubtreeMatchCriteria<NNNodeMatchCriteria>;

bool hasSingleOutputAndConsumer(NNGraph::NodeRef nodeRef);

template <typename NodeType>
NNNodeMatchCriteria matchNodeTypeWithPredicate(
const std::function<bool(NNGraph::NodeRef, const NodeType&)> predicate,
bool expectedSingleOutputAndConsumer = false) {
return
[&predicate, expectedSingleOutputAndConsumer](NNGraph::NodeRef nodeRef) {
NOM_REQUIRE_OR_RET_FALSE(is<NodeType>(nodeRef));
if (expectedSingleOutputAndConsumer) {
NOM_REQUIRE_OR_RET_FALSE(hasSingleOutputAndConsumer(nodeRef));
}
NodeType* node = get<NodeType>(nodeRef);
return predicate(nodeRef, *node);
};
};

template <typename NodeType>
NNNodeMatchCriteria matchNodeType(
bool expectedSingleOutputAndConsumer = false) {
return [expectedSingleOutputAndConsumer](NNGraph::NodeRef nodeRef) {
if (expectedSingleOutputAndConsumer) {
NOM_REQUIRE_OR_RET_FALSE(hasSingleOutputAndConsumer(nodeRef));
}
return is<NodeType>(nodeRef);
};
}

NNNodeMatchCriteria matchAnyNode();

struct NNNodeMatch {
static bool isMatch(
const NNGraph::NodeRef& node,
const NNNodeMatchCriteria& criteria) {
return criteria(node);
}
};

using NNSubgraphMatcher =
nom::matcher::SubgraphMatcher<NNGraph, NNNodeMatchCriteria, NNNodeMatch>;

// This helper method makes it easy to create matching criteria in NNGraph.
// For example, operatorTree(opMatch, ...) will refer to a tree like this:
// ... -> opMatch -> opMatch_Output
NNSubtree operatorTree(
const NNNodeMatchCriteria& root,
const std::vector<NNSubtree>& childrenCriteria = {},
int count = 1);

} // namespace nn

} // namespace repr
Expand Down
1 change: 1 addition & 0 deletions caffe2/core/nomnigraph/include/nomnigraph/Support/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#define NOM_REQUIRE_OR_BREAK(_cond) NOM_REQUIRE_OR_(_cond, break)
#define NOM_REQUIRE_OR_RET_NULL(_cond) NOM_REQUIRE_OR_(_cond, return nullptr)
#define NOM_REQUIRE_OR_RET_FALSE(_cond) NOM_REQUIRE_OR_(_cond, return false)
#define NOM_REQUIRE_OR_RET_TRUE(_cond) NOM_REQUIRE_OR_(_cond, return true)
#define NOM_REQUIRE_OR_RET(_cond) NOM_REQUIRE_OR_(_cond, return )

// Implements accessors for a generic type T. If the type is not
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef NOM_TRANFORMATIONS_SUBGRAPH_MATCHER_H
#define NOM_TRANFORMATIONS_SUBGRAPH_MATCHER_H

#include <functional>
#include <vector>

namespace nom {

namespace matcher {
Expand All @@ -10,23 +13,37 @@ namespace matcher {
* - Node matching criteria for the subtree's root.
* - Children subtree matching criteria
* - A count, which means we may want more than one of this subtree. The count
* can be unlimited. The count is only used when we match children of a
* subtree root, not matching the subtree itself.
* can be unlimited. The count is only used when we match children of a subtree
* root, not matching the subtree itself.
* - If nonTerminal flag is set, it means we only match the root and do not
* care about the children.
*/
template <typename NodeMatchCriteria>
class SubtreeMatchCriteria {
public:
static const int kStarCount = -1;
SubtreeMatchCriteria(
const NodeMatchCriteria& root,
const std::vector<SubtreeMatchCriteria>& children,
int count)
: root_(root), children_(children), count_(count){};
const std::vector<SubtreeMatchCriteria>& children = {},
int count = 1,
bool nonTerminal = false)
: root_(root),
children_(children),
count_(count),
nonTerminal_(nonTerminal){};

// Non terminal
static SubtreeMatchCriteria<NodeMatchCriteria> nonTerminal(
const NodeMatchCriteria& root,
int count = 1) {
return SubtreeMatchCriteria(root, {}, count, true);
}

private:
NodeMatchCriteria root_;
std::vector<SubtreeMatchCriteria> children_;
int count_;
bool nonTerminal_;

template <typename, typename, typename>
friend class SubgraphMatcher;
Expand Down Expand Up @@ -58,6 +75,11 @@ struct SubgraphMatcher {
if (!isNodeMatch(root, criteria.root_)) {
return false;
}
if (criteria.nonTerminal_) {
// This is sufficient to be a match if this criteria specifies a non
// terminal node.
return true;
}
auto& edges =
invertGraphTraversal ? root->getInEdges() : root->getOutEdges();

Expand Down Expand Up @@ -87,9 +109,9 @@ struct SubgraphMatcher {
(isStarCount || countMatch < expectedCount);
currentEdgeIdx++) {
auto edge = edges[currentEdgeIdx];
auto nextNode = invertGraphTraversal ? edge->tail() : edge->head();
auto child = invertGraphTraversal ? edge->tail() : edge->head();

if (!isSubtreeMatch(nextNode, childrenCriteria, invertGraphTraversal)) {
if (!isSubtreeMatch(child, childrenCriteria, invertGraphTraversal)) {
if (!isStarCount) {
// If the current criteria isn't a * pattern, this indicates a
// failure.
Expand Down Expand Up @@ -150,23 +172,6 @@ struct SubgraphMatcher {
}
};

// Convenient methods to create subtree matching criteria.
template <typename NodeMatchCriteria>
SubtreeMatchCriteria<NodeMatchCriteria> tree(
const NodeMatchCriteria& root,
const std::vector<SubtreeMatchCriteria<NodeMatchCriteria>>& children = {},
int count = 1) {
return SubtreeMatchCriteria<NodeMatchCriteria>(root, children, count);
}

template <typename NodeMatchCriteria>
SubtreeMatchCriteria<NodeMatchCriteria> treeStar(
const NodeMatchCriteria& root,
const std::vector<SubtreeMatchCriteria<NodeMatchCriteria>>& children = {}) {
return tree(
root, children, SubtreeMatchCriteria<NodeMatchCriteria>::kStarCount);
}

} // namespace matcher

} // namespace nom
Expand Down
92 changes: 92 additions & 0 deletions caffe2/core/nomnigraph/tests/neural_net_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include <algorithm>

#include "test_util.h"

#include "nomnigraph/Representations/NeuralNet.h"
#include "nomnigraph/Support/Pointer.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"

#include <gtest/gtest.h>

using namespace nom;
using namespace nom::repr;
using namespace nom::repr::nn;

// Test for the NNGraph subgraph matching APIs.
TEST(NeuralNetGraph, ReplaceGraph) {
NNGraph graph;

auto input1 = graph.createNode(util::make_unique<Tensor>("input1"));
auto input2 = graph.createNode(util::make_unique<Tensor>("input2"));
auto sum = graph.createNode(util::make_unique<Sum>());
auto sumOutput = graph.createNode(util::make_unique<Tensor>("sumOutput"));
auto relu = graph.createNode(util::make_unique<Relu>());
auto reluOutput = graph.createNode(util::make_unique<Tensor>("reluOutput"));

graph.createEdge(input1, sum);
graph.createEdge(input2, sum);
graph.createEdge(sum, sumOutput);
graph.createEdge(sumOutput, relu);
graph.createEdge(relu, reluOutput);

/* input1 input2
\ /
\ /
sum
|
|
sumOutput
|
relu
|
reluOutput
*/

// clang-format off
auto pattern = NNSubtree(
matchNodeType<Relu>(), {
operatorTree(
matchNodeType<Sum>(), {
NNSubtree::nonTerminal(matchNodeType<Tensor>(), 2)
}),
});
// clang-format on

EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(sum, pattern));
EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(reluOutput, pattern));
EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(input1, pattern));

EXPECT_TRUE(NNSubgraphMatcher::isSubtreeMatch(relu, pattern));

NNSubgraphMatcher::replaceSubtree(
graph, pattern, [](NNGraph& g, NNGraph::NodeRef relu) {
auto sumOutput = getInputs(relu)[0];
auto sum = getProducer(sumOutput);

auto fusedNode = g.createNode(util::make_unique<SumRelu>());
g.deleteNode(sumOutput);
g.replaceNode(relu, fusedNode);
g.replaceNode(sum, fusedNode);

g.deleteNode(sum);
g.deleteNode(relu);

return true;
});

/*
Fused graph:

input1 input2
\ /
\ /
sumRelu
|
|
output
*/
EXPECT_EQ(graph.getNodesCount(), 4);
auto fusedNode = getProducer(reluOutput);
EXPECT_TRUE(is<SumRelu>(fusedNode));
EXPECT_EQ(getInputs(fusedNode).size(), 2);
}
Loading