
Commit 81e3b66

GraphIR: Add the 'name' field to the node constructor, and print the node name for all nodes.

1 parent cd6f5c7

3 files changed: 65 additions, 53 deletions

include/glow/Graph/Graph.h (24 additions, 19 deletions)

@@ -64,41 +64,46 @@ class Graph final {
                            WeightVar::InitKind initKind = WeightVar::InitKind::Broadcast,
                            float val = 0.0);
 
-  ConvolutionNode *createConv(Node *input, size_t depth, size_t kernel,
-                              size_t stride, size_t pad);
+  ConvolutionNode *createConv(llvm::StringRef name, Node *input, size_t depth,
+                              size_t kernel, size_t stride, size_t pad);
 
-  PoolNode *createPool(Node *input, PoolInst::OpKind kind, size_t kernel,
-                       size_t stride, size_t pad);
+  PoolNode *createPool(llvm::StringRef name, Node *input, PoolInst::OpKind kind,
+                       size_t kernel, size_t stride, size_t pad);
 
-  FullyConnectedNode *createFullyConnected(Node *input, size_t outDepth);
+  FullyConnectedNode *createFullyConnected(llvm::StringRef name, Node *input,
+                                           size_t outDepth);
 
-  ReluNode *createRELU(Node *input);
+  ReluNode *createRELU(llvm::StringRef name, Node *input);
 
-  SigmoidNode *createSigmoid(Node *input);
+  SigmoidNode *createSigmoid(llvm::StringRef name, Node *input);
 
-  TanhNode *createTanh(Node *input);
+  TanhNode *createTanh(llvm::StringRef name, Node *input);
 
-  SoftMaxNode *createSoftMax(Node *input, Node *selected);
+  SoftMaxNode *createSoftMax(llvm::StringRef name, Node *input, Node *selected);
 
-  RegressionNode *createRegression(Node *input, Node *expected);
+  RegressionNode *createRegression(llvm::StringRef name, Node *input,
+                                   Node *expected);
 
-  ReshapeNode *createReshape(Node *input, llvm::ArrayRef<size_t> shape);
+  ReshapeNode *createReshape(llvm::StringRef name, Node *input,
+                             llvm::ArrayRef<size_t> shape);
 
-  TransposeNode *createTranspose(Node *input, llvm::ArrayRef<unsigned> shuffle);
+  TransposeNode *createTranspose(llvm::StringRef name, Node *input,
+                                 llvm::ArrayRef<unsigned> shuffle);
 
-  ConcatNode *createConcat(llvm::ArrayRef<Node *> inputs, unsigned dimension);
+  ConcatNode *createConcat(llvm::StringRef name, llvm::ArrayRef<Node *> inputs,
+                           unsigned dimension);
 
-  BatchNormalizationNode *createBatchNormalization(Node *input,
+  BatchNormalizationNode *createBatchNormalization(llvm::StringRef name,
+                                                   Node *input,
                                                    size_t channelIdx = 0,
                                                    float epsilon = 1e-5,
                                                    float momentum = 0.9);
 
-  LocalResponseNormalizationNode *
-  createLocalResponseNormalization(Node *input, size_t halfWindowSize = 2,
-                                   float alpha = 1e-4, float beta = 0.75,
-                                   float k = 2.0);
+  LocalResponseNormalizationNode *createLocalResponseNormalization(
+      llvm::StringRef name, Node *input, size_t halfWindowSize = 2,
+      float alpha = 1e-4, float beta = 0.75, float k = 2.0);
 
-  ArithmeticNode *createArithmetic(Node *LHS, Node *RHS,
+  ArithmeticNode *createArithmetic(llvm::StringRef name, Node *LHS, Node *RHS,
                                    ArithmeticInst::OpKind op);
   /// @}
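With this change, every node-builder method on Graph takes the node's name as its first parameter instead of hard-coding a label inside the builder. Below is a minimal sketch of what call sites look like after this commit; the graph instance `G`, the tensor shape, and the node names are hypothetical, not taken from the commit:

```cpp
// Hypothetical call sites after this commit: each builder now receives an
// explicit, caller-chosen node name as its first argument.
auto *input = G.createVariable(ElemKind::FloatTy, {1, 28, 28, 1}, "input");
auto *conv = G.createConv("conv1", input, /*depth=*/16, /*kernel=*/5,
                          /*stride=*/1, /*pad=*/2);
auto *relu = G.createRELU("relu1", conv);
auto *fc = G.createFullyConnected("fc1", relu, /*outDepth=*/10);
```

The name then flows through to the node constructors in Graph.cpp, so two nodes of the same kind in one graph can carry distinct, meaningful names in debug output.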

src/glow/Graph/Graph.cpp (40 additions, 33 deletions)

@@ -36,8 +36,9 @@ Variable *Graph::createVariable(ElemKind T, llvm::ArrayRef<size_t> dims,
   return createVariable(FT, name, initKind, val);
 }
 
-ConvolutionNode *Graph::createConv(Node *input, size_t depth, size_t kernel,
-                                   size_t stride, size_t pad) {
+ConvolutionNode *Graph::createConv(llvm::StringRef name, Node *input,
+                                   size_t depth, size_t kernel, size_t stride,
+                                   size_t pad) {
   ShapeNHWC idim = ShapeNHWC(input->dims());
   assert(idim.w >= kernel && idim.h >= kernel &&
          "buffer too small for selected stride");
@@ -59,12 +60,13 @@ ConvolutionNode *Graph::createConv(Node *input, size_t depth, size_t kernel,
 
   auto OT = M_.uniqueType(ElemKind::FloatTy, outDims);
 
-  return addNode(new ConvolutionNode(input, OT, "Conv", filter, bias, kernel,
+  return addNode(new ConvolutionNode(input, OT, name, filter, bias, kernel,
                                      stride, pad, depth));
 }
 
-PoolNode *Graph::createPool(Node *input, PoolInst::OpKind kind, size_t kernel,
-                            size_t stride, size_t pad) {
+PoolNode *Graph::createPool(llvm::StringRef name, Node *input,
+                            PoolInst::OpKind kind, size_t kernel, size_t stride,
+                            size_t pad) {
   ShapeNHWC idim = ShapeNHWC(input->dims());
   assert(idim.w >= kernel && idim.h >= kernel &&
          "buffer too small for selected stride");
@@ -75,10 +77,11 @@ PoolNode *Graph::createPool(Node *input, PoolInst::OpKind kind, size_t kernel,
   auto OT = M_.uniqueType(ElemKind::FloatTy,
                           {idim.n, outSz.first, outSz.second, idim.c});
 
-  return addNode(new PoolNode(input, OT, "pool", kind, kernel, stride, pad));
+  return addNode(new PoolNode(input, OT, name, kind, kernel, stride, pad));
 }
 
-FullyConnectedNode *Graph::createFullyConnected(Node *input, size_t outDepth) {
+FullyConnectedNode *Graph::createFullyConnected(llvm::StringRef name,
+                                                Node *input, size_t outDepth) {
   TypeRef T = input->getType();
   auto idim = flattenCdr(input->dims());
 
@@ -91,35 +94,37 @@ FullyConnectedNode *Graph::createFullyConnected(Node *input, size_t outDepth) {
                            WeightVar::InitKind::Xavier, 0.1);
 
   auto OT = M_.uniqueType(T->getElementType(), {idim.first, outDepth});
-  return addNode(
-      new FullyConnectedNode(input, OT, "Fullyconnected", W, B, outDepth));
+  return addNode(new FullyConnectedNode(input, OT, name, W, B, outDepth));
 }
 
-ReluNode *Graph::createRELU(Node *input) {
-  return addNode(new ReluNode(input, "relu"));
+ReluNode *Graph::createRELU(llvm::StringRef name, Node *input) {
+  return addNode(new ReluNode(input, name));
 }
 
-SigmoidNode *Graph::createSigmoid(Node *input) {
-  return addNode(new SigmoidNode(input, "Sigmoid"));
+SigmoidNode *Graph::createSigmoid(llvm::StringRef name, Node *input) {
+  return addNode(new SigmoidNode(input, name));
 }
 
-TanhNode *Graph::createTanh(Node *input) {
-  return addNode(new TanhNode(input, "Tanh"));
+TanhNode *Graph::createTanh(llvm::StringRef name, Node *input) {
+  return addNode(new TanhNode(input, name));
 }
 
-SoftMaxNode *Graph::createSoftMax(Node *input, Node *selected) {
-  return addNode(new SoftMaxNode(input, "SoftMax", selected));
+SoftMaxNode *Graph::createSoftMax(llvm::StringRef name, Node *input,
+                                  Node *selected) {
+  return addNode(new SoftMaxNode(input, name, selected));
 }
 
-RegressionNode *Graph::createRegression(Node *input, Node *expected) {
-  return addNode(new RegressionNode(input, "Regression", expected));
+RegressionNode *Graph::createRegression(llvm::StringRef name, Node *input,
+                                        Node *expected) {
+  return addNode(new RegressionNode(input, name, expected));
 }
 
-ReshapeNode *Graph::createReshape(Node *input, llvm::ArrayRef<size_t> shape) {
-  return addNode(new ReshapeNode(input, "Reshape", shape));
+ReshapeNode *Graph::createReshape(llvm::StringRef name, Node *input,
+                                  llvm::ArrayRef<size_t> shape) {
+  return addNode(new ReshapeNode(input, name, shape));
 }
 
-TransposeNode *Graph::createTranspose(Node *input,
+TransposeNode *Graph::createTranspose(llvm::StringRef name, Node *input,
                                       llvm::ArrayRef<unsigned> shuffle) {
   std::vector<size_t> shape;
   auto dims = input->dims();
@@ -128,10 +133,11 @@ TransposeNode *Graph::createTranspose(Node *input,
   }
 
   auto NT = M_.uniqueType(input->getElementType(), shape);
-  return addNode(new TransposeNode(input, NT, "Transpose", shuffle));
+  return addNode(new TransposeNode(input, NT, name, shuffle));
 }
 
-ConcatNode *Graph::createConcat(llvm::ArrayRef<Node *> inputs,
+ConcatNode *Graph::createConcat(llvm::StringRef name,
+                                llvm::ArrayRef<Node *> inputs,
                                 unsigned dimension) {
   auto inDim = inputs[0]->dims();
 
@@ -146,10 +152,11 @@ ConcatNode *Graph::createConcat(llvm::ArrayRef<Node *> inputs,
   shape[dimension] *= inputs.size();
 
   auto NT = M_.uniqueType(inputs[0]->getElementType(), shape);
-  return addNode(new ConcatNode(inputs, NT, "Concat", dimension));
+  return addNode(new ConcatNode(inputs, NT, name, dimension));
 }
 
-BatchNormalizationNode *Graph::createBatchNormalization(Node *input,
+BatchNormalizationNode *Graph::createBatchNormalization(llvm::StringRef name,
+                                                        Node *input,
                                                         size_t channelIdx,
                                                         float epsilon,
                                                         float momentum) {
@@ -167,14 +174,14 @@ BatchNormalizationNode *Graph::createBatchNormalization(Node *input,
   auto *variance = createVariable(ElemKind::FloatTy, {channels}, "variance",
                                   WeightVar::InitKind::Broadcast, 0.0);
 
-  return addNode(new BatchNormalizationNode(input, "Norm", gamma, beta, mean,
-                                            variance, channelIdx, epsilon,
-                                            momentum));
+  return addNode(new BatchNormalizationNode(
+      input, name, gamma, beta, mean, variance, channelIdx, epsilon, momentum));
 }
 
 LocalResponseNormalizationNode *
-Graph::createLocalResponseNormalization(Node *input, size_t halfWindowSize,
-                                        float alpha, float beta, float k) {
+Graph::createLocalResponseNormalization(llvm::StringRef name, Node *input,
+                                        size_t halfWindowSize, float alpha,
+                                        float beta, float k) {
   auto Ty = input->getType();
   auto *scale =
       createVariable(Ty, "scale", WeightVar::InitKind::Broadcast, 0.0);
@@ -184,8 +191,8 @@ Graph::createLocalResponseNormalization(Node *input, size_t halfWindowSize,
       input, "LRN", scale, halfWindowSize, alpha, beta, k));
 }
 
-ArithmeticNode *Graph::createArithmetic(Node *LHS, Node *RHS,
-                                        ArithmeticInst::OpKind op) {
+ArithmeticNode *Graph::createArithmetic(llvm::StringRef name, Node *LHS,
+                                        Node *RHS, ArithmeticInst::OpKind op) {
   assert(LHS->dims() == RHS->dims() && "Invalid operand shapes");
   // The output tensor is of the same shape as the input tensor.
   return addNode(new ArithmeticNode("Arithmetic", LHS, RHS, op));
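Note that, as of this commit, the bodies of createLocalResponseNormalization and createArithmetic shown above still pass the hard-coded literals "LRN" and "Arithmetic" to their node constructors; their signatures accept the new name parameter, but it is not yet forwarded for those two node kinds.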

src/glow/Graph/Nodes.cpp (1 addition, 1 deletion)

@@ -183,7 +183,7 @@ std::string ArithmeticNode::getDebugDesc() const {
 #define DEFINE_CLASS_REPR(CLASS_NAME)                                         \
   std::string CLASS_NAME::getDebugDesc() const {                              \
     DescriptionBuilder db(getKindName());                                     \
-    db.addParam("input", *in_->getType());                                    \
+    db.addParam("name", quote(getName())).addParam("input", *in_->getType()); \
    return db;                                                                 \
   }
 
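For illustration, here is roughly what DEFINE_CLASS_REPR expands to for one of the listed classes after this change, taking ReluNode as an assumed example (DescriptionBuilder, quote, and the in_ member are existing Glow helpers that are not part of this diff):

```cpp
// Approximate expansion of DEFINE_CLASS_REPR(ReluNode) after this commit.
// The debug description now reports the node's name ahead of its input type;
// the exact output format depends on DescriptionBuilder, not shown here.
std::string ReluNode::getDebugDesc() const {
  DescriptionBuilder db(getKindName());
  db.addParam("name", quote(getName())).addParam("input", *in_->getType());
  return db;
}
```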
