Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.

Commit cd6f5c7

Browse files
committed
GraphIR: Implement the dotty-printing of the graph.
1 parent 01c8a2e commit cd6f5c7

File tree

4 files changed

+90
-11
lines changed

4 files changed

+90
-11
lines changed

include/glow/Graph/Graph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ class Graph final {
104104

105105
/// Dumps the textual representation of the network.
106106
void dump();
107+
108+
/// Dump a dotty graph that depicts the module.
109+
void dumpDAG();
107110
};
108111

109112
} // namespace glow

src/glow/Graph/Graph.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include "glow/Graph/Graph.h"
44
#include "glow/Graph/Nodes.h"
55
#include "glow/IR/IR.h"
6+
#include "glow/Support/Support.h"
67

8+
#include <fstream>
79
#include <iostream>
810

911
using namespace glow;
@@ -203,3 +205,77 @@ void Graph::dump() {
203205
std::cout << n->getDebugDesc() << "\n";
204206
}
205207
}
208+
209+
/// A helper class for visiting and generating the dotty file from the graph.
210+
struct DottyPrinterPass : NodeVisitor {
211+
using edgeTy = std::pair<Node *, Node *>;
212+
std::vector<edgeTy> nodeEdges{};
213+
214+
public:
215+
// Don't revisit visited nodes.
216+
bool shouldVisit(Node *parent, Node *N) override {
217+
edgeTy e = {parent, N};
218+
return std::find(nodeEdges.begin(), nodeEdges.end(), e) == nodeEdges.end();
219+
}
220+
221+
DottyPrinterPass() = default;
222+
223+
void pre(Node *parent, Node *N) override { nodeEdges.push_back({parent, N}); }
224+
225+
std::string nodeDescr(Node *N) {
226+
if (!N) {
227+
return "";
228+
}
229+
// Print a node descriptor that looks like this:
230+
// Format: "node12" [ label = "0xf7fc43e01" shape = "record" ];
231+
std::string sb;
232+
sb += quote(std::to_string((void *)N)) + "[\n";
233+
std::string repr = escapeDottyString(N->getDebugDesc());
234+
sb += "\tlabel = " + repr + "\n";
235+
sb += "\tshape = \"record\"\n";
236+
sb += "];\n\n";
237+
return sb;
238+
}
239+
240+
std::string quote(std::string in) { return '"' + in + '"'; }
241+
std::string getDottyString() {
242+
std::string sb;
243+
244+
sb += "digraph finite_state_machine {\n\trankdir=TD;\n";
245+
246+
// Assign a unique name to each one of the nodes:
247+
for (auto &e : nodeEdges) {
248+
if (e.first) {
249+
sb += quote(std::to_string(e.second)) + " -> " +
250+
quote(std::to_string(e.first)) + ";\n";
251+
}
252+
}
253+
254+
// Assign a unique name to each one of the nodes:
255+
for (auto &e : nodeEdges) {
256+
sb += nodeDescr(e.first);
257+
sb += nodeDescr(e.second);
258+
}
259+
260+
sb += "}";
261+
return sb;
262+
}
263+
};
264+
265+
void Graph::dumpDAG() {
266+
DottyPrinterPass DP;
267+
268+
for (auto &N : nodes_) {
269+
N->visit(nullptr, &DP);
270+
}
271+
272+
std::string filename = "dotty_graph_dump_" + std::to_string(this) + ".dot";
273+
std::cout << "Writing dotty graph to: " << filename << '\n';
274+
275+
std::string rep = DP.getDottyString();
276+
277+
std::ofstream myfile;
278+
myfile.open(filename);
279+
myfile << rep;
280+
myfile.close();
281+
}

src/glow/Graph/Nodes.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ std::string Node::getDebugDesc() const { return "<node>"; }
7878

7979
std::string Variable::getDebugDesc() const {
8080
DescriptionBuilder db(getKindName());
81-
db.addParam("name ", quote(getName()))
81+
db.addParam("name", quote(getName()))
8282
.addParam("output", *getType())
8383
.addParam("init", WeightVar::getInitKindStr(initKind_))
8484
.addParam("val", val_);
@@ -87,7 +87,7 @@ std::string Variable::getDebugDesc() const {
8787

8888
std::string ConvolutionNode::getDebugDesc() const {
8989
DescriptionBuilder db(getKindName());
90-
db.addParam("name ", quote(getName()))
90+
db.addParam("name", quote(getName()))
9191
.addParam("input", *in_->getType())
9292
.addParam("output", *getType())
9393
.addParam("filter", *filter_->getType())
@@ -100,7 +100,7 @@ std::string ConvolutionNode::getDebugDesc() const {
100100
}
101101
std::string PoolNode::getDebugDesc() const {
102102
DescriptionBuilder db(getKindName());
103-
db.addParam("name ", quote(getName()))
103+
db.addParam("name", quote(getName()))
104104

105105
.addParam("input", *in_->getType())
106106
.addParam("output", *getType())
@@ -113,7 +113,7 @@ std::string PoolNode::getDebugDesc() const {
113113

114114
std::string FullyConnectedNode::getDebugDesc() const {
115115
DescriptionBuilder db(getKindName());
116-
db.addParam("name ", quote(getName()))
116+
db.addParam("name", quote(getName()))
117117
.addParam("input", *in_->getType())
118118
.addParam("output", *getType())
119119
.addParam("filter", *filter_->getType())
@@ -124,7 +124,7 @@ std::string FullyConnectedNode::getDebugDesc() const {
124124

125125
std::string LocalResponseNormalizationNode::getDebugDesc() const {
126126
DescriptionBuilder db(getKindName());
127-
db.addParam("name ", quote(getName()))
127+
db.addParam("name", quote(getName()))
128128
.addParam("input", *in_->getType())
129129
.addParam("alpha", alpha_)
130130
.addParam("beta", beta_)
@@ -135,7 +135,7 @@ std::string LocalResponseNormalizationNode::getDebugDesc() const {
135135

136136
std::string ConcatNode::getDebugDesc() const {
137137
DescriptionBuilder db(getKindName());
138-
db.addParam("name ", quote(getName()));
138+
db.addParam("name", quote(getName()));
139139

140140
for (auto input : in_) {
141141
db.addParam("input", *input->getType());
@@ -146,23 +146,23 @@ std::string ConcatNode::getDebugDesc() const {
146146

147147
std::string SoftMaxNode::getDebugDesc() const {
148148
DescriptionBuilder db(getKindName());
149-
db.addParam("name ", quote(getName()))
149+
db.addParam("name", quote(getName()))
150150
.addParam("input", *in_->getType())
151151
.addParam("selected", *selected_->getType());
152152
return db;
153153
}
154154

155155
std::string RegressionNode::getDebugDesc() const {
156156
DescriptionBuilder db(getKindName());
157-
db.addParam("name ", quote(getName()))
157+
db.addParam("name", quote(getName()))
158158
.addParam("input", *in_->getType())
159159
.addParam("expected", *expected_->getType());
160160
return db;
161161
}
162162

163163
std::string BatchNormalizationNode::getDebugDesc() const {
164164
DescriptionBuilder db(getKindName());
165-
db.addParam("name ", quote(getName()))
165+
db.addParam("name", quote(getName()))
166166
.addParam("input", *in_->getType())
167167
.addParam("beta", *bias_->getType())
168168
.addParam("gamma", *scale_->getType())
@@ -174,7 +174,7 @@ std::string BatchNormalizationNode::getDebugDesc() const {
174174

175175
std::string ArithmeticNode::getDebugDesc() const {
176176
DescriptionBuilder db(getKindName());
177-
db.addParam("name ", quote(getName()))
177+
db.addParam("name", quote(getName()))
178178
.addParam("output", *getType())
179179
.addParam("op", kind_ == ArithmeticInst::OpKind::Add ? "add" : "mul");
180180
return db;

src/glow/IR/IR.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ static const char *getDottyArrowForCC(OperandKind k) {
256256

257257
/// Dump a dotty graph that depicts the module.
258258
void Module::dumpDAG() {
259-
std::string filename = "dotty_network_dump_" + std::to_string(this) + ".dot";
259+
std::string filename = "dotty_ir_dump_" + std::to_string(this) + ".dot";
260260
std::cout << "Writing dotty graph to: " << filename << '\n';
261261

262262
std::string sb;

0 commit comments

Comments
 (0)