Skip to content

Commit 3685608

Browse files
[Graph class] add State to validate order of steps
1 parent b0c6bc6 commit 3685608

File tree

6 files changed

+36
-0
lines changed

6 files changed

+36
-0
lines changed

include/glow/Graph/Graph.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ using UnsignedArrayRef = llvm::ArrayRef<size_t>;
2121

2222
/// Represents the compute graph.
2323
class Graph final : public Named {
24+
public:
25+
enum class State {
26+
Created,
27+
Differentiated,
28+
Lowered,
29+
Optimized,
30+
IRGenerated,
31+
};
32+
33+
private:
34+
/// A current state of the graph.
35+
State state_{State::Created};
2436
/// A uniqued list of types in the module. Types in this list can be equated
2537
/// by comparing their addresses.
2638
TypesList types_{};
@@ -49,13 +61,17 @@ class Graph final : public Named {
4961

5062
/// Inserts the node \p N to the list of nodes, and returns the inserted node.
5163
template <class NodeTy> NodeTy *addNode(NodeTy *N) {
64+
assert(state_ < State::IRGenerated &&
65+
"Trying to add Node when IR is already generated.");
5266
uniqueNames(N);
5367
nodes_.push_back(N);
5468
return N;
5569
}
5670

5771
/// Inserts the variable \p V to the list of variables.
5872
Variable *addVar(Variable *V) {
73+
assert(state_ < State::IRGenerated &&
74+
"Trying to add Variable when IR is already generated.");
5975
uniqueNames(V);
6076
vars_.push_back(V);
6177
return V;
@@ -199,6 +215,13 @@ class Graph final : public Named {
199215
/// Returns nullptr if there is no gradient variable
200216
/// related to this variable.
201217
Variable *getGradientVariable(Variable *V);
218+
219+
/// Resets current state to Created.
220+
void resetState();
221+
222+
/// Verifies that current state of the graph is not later then \p s
223+
/// and assigns current state to be \p s.
224+
void advanceState(State s);
202225
};
203226

204227
struct TrainingConfig;

src/glow/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ void ExecutionEngine::compile(CompilationMode mode) {
8888
// Wipe out the module and start a new compilation process.
8989
M_->clear();
9090
IP_->clear();
91+
G_->resetState();
9192

9293
if (mode != CompilationMode::Infer) {
9394
generateGradientNodes(*G_, getConfig(), mode);

src/glow/Graph/Graph.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ std::string Graph::uniqueName(llvm::StringRef name) {
9191
void Graph::uniqueNames(Node *N) { N->setName(uniqueName(N->getName())); }
9292

9393
void Graph::addGradientVariable(Variable *V, Variable *GradV) {
94+
advanceState(State::Differentiated);
9495
grads_.push_back({V, GradV});
9596
}
9697

@@ -650,3 +651,10 @@ void Graph::verify() const {
650651
llvm_unreachable("Multiple nodes with the same name");
651652
}
652653
}
654+
655+
void Graph::resetState() { state_ = State::Created; }
656+
657+
void Graph::advanceState(State s) {
658+
assert(state_ <= s && "Wrong order of actions with a graph.");
659+
state_ = s;
660+
}

src/glow/IR/IRGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ struct IRGenVisitor : NodeWalker {
530530
} // namespace
531531

532532
void Module::generateIR(CompilationMode mode) {
533+
G_->advanceState(Graph::State::IRGenerated);
533534
G_->verify();
534535
IRGenVisitor irgen(this);
535536

src/glow/Optimizer/GraphOptimizer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ static void OptimizeSliceOfSplat(Graph &G) {
546546
}
547547

548548
void glow::optimize(Graph &G, CompilationMode mode) {
549+
G.advanceState(Graph::State::Optimized);
550+
549551
// Sink transpose operations in an attempt to cancel them out.
550552
sinkCode(G);
551553

src/glow/Optimizer/Lower.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ void lowerReluNode(Graph &graph, ReluNode &R) {
223223
}
224224

225225
void glow::lower(Graph &G, CompilationMode mode) {
226+
G.advanceState(Graph::State::Lowered);
226227
auto &nodes = G.getNodes();
227228

228229
for (auto const &node : nodes) {

0 commit comments

Comments
 (0)