
Commit 74f88b3

nrsatish authored and beicy committed
Add data structures for compute and communication time; add function to fill in compute and memory bandwidth bound times for ops
1 parent 4d751bd commit 74f88b3

File tree: 4 files changed, 244 additions and 4 deletions

  include/glow/Partitioner/Partitioner.h
  include/glow/Runtime/RuntimeTypes.h
  lib/Partitioner/Partitioner.cpp
  tests/unittests/PartitionerTest.cpp

include/glow/Partitioner/Partitioner.h

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,7 @@ namespace glow {
 using namespace runtime;
 
 using MemUsageMapTy = std::unordered_map<Node *, size_t>;
+using ComputeTimeMapTy = std::unordered_map<Node *, float>;
 using NodesSetTy = std::set<Node *>;
 using PartitionCostMapTy = llvm::DenseMap<Function *, GraphMemInfo>;
 
@@ -97,6 +98,9 @@ class Partitioner {
   /// The map of each operator and the corresponding memory size.
   MemUsageMapTy memUsage_;
 
+  /// The map of each operator and the compute runtime.
+  ComputeTimeMapTy computeTime_;
+
   /// Get the representative function (the one with the largest input) and
   /// update the memSize.
   static Function *selectRepFunc(Module *parent, size_t &memSize);
@@ -105,6 +109,9 @@ class Partitioner {
   /// function.
   void initOpMemUsage();
 
+  /// Initialize the minimal compute time for each op in the function.
+  void initOpComputeTime();
+
   /// Combine the partitions if necessary : if all outside uses of the nodes in
   /// partition1 is in partition2, and the sum of memory consumption of
   /// partition1 and partition2 is less than availableMemory, combine partition1
@@ -140,6 +147,9 @@ class Partitioner {
 
   /// Decompose each function in a module and return a list of DAGNodes.
   DAGNodeList &Partition();
+
+  /// Getter for computeTime_.
+  ComputeTimeMapTy getComputeTime() const { return computeTime_; }
 };
 } // namespace glow
 #endif // GLOW_PARTITIONER_PARTITIONER_H
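The new accessor returns the Node-to-time map by value, so callers can inspect the per-op estimates after Partition() has run. A minimal, hypothetical consumer is sketched below; dumpComputeTime is an illustrative helper, not part of this commit or of the Glow API, and it assumes Partition() has already filled computeTime_.

#include <cstdio>
#include "glow/Partitioner/Partitioner.h"

using namespace glow;

// Print the roofline estimate recorded for every node of the
// representative function (sketch only).
void dumpComputeTime(Partitioner &partitioner) {
  for (const auto &entry : partitioner.getComputeTime()) {
    // entry.first is the Node*, entry.second the estimated time given the
    // device's ops/s and bytes/s peaks.
    printf("%s: %f\n", entry.first->getName().str().c_str(), entry.second);
  }
}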

include/glow/Runtime/RuntimeTypes.h

Lines changed: 11 additions & 0 deletions
@@ -48,6 +48,17 @@ using ResultCBTy = std::function<void(
 struct DeviceInfo {
   /// Available memory on device in bytes.
   size_t availableMemory;
+  /// Available SRAM capacity in bytes.
+  size_t sramCapacity;
+  /// Peak compute on device in ops/second. Assumes all ops are in int8.
+  /// TODO: distinguish between data types with different peak flops.
+  float peakCompute;
+  /// Peak memory bandwidth from DRAM on device in bytes/second.
+  float peakDramBw;
+  /// Peak memory bandwidth from SRAM on device in bytes/second.
+  float peakSramBw;
+  /// Peak ingress/egress PCI-E bandwidth from device in bytes/second.
+  float peakPCIeBw;
 };
 
 /// Individual Node in the DAG for a given network. This contains all the
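With the new fields, a DeviceInfo can be brace-initialized in declaration order (availableMemory, sramCapacity, peakCompute, peakDramBw, peakSramBw, peakPCIeBw). A small sketch using the same toy numbers as the new unit test below; makeToyDevice is an illustrative helper, the values are test placeholders rather than real hardware figures, and the glow::runtime namespace is assumed from the header's other runtime types.

#include "glow/Runtime/RuntimeTypes.h"

using namespace glow::runtime;

// Toy device: 3072 B of device memory, 100 B of SRAM, 10 ops/s peak compute,
// 0.1 B/s DRAM bandwidth, 1 B/s SRAM bandwidth, 0.05 B/s PCI-E bandwidth.
DeviceInfo makeToyDevice() {
  return DeviceInfo{/* availableMemory */ 3072,
                    /* sramCapacity */ 100,
                    /* peakCompute */ 10,
                    /* peakDramBw */ 0.1f,
                    /* peakSramBw */ 1,
                    /* peakPCIeBw */ 0.05f};
}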

lib/Partitioner/Partitioner.cpp

Lines changed: 122 additions & 4 deletions
@@ -81,6 +81,126 @@ void Partitioner::initOpMemUsage() {
   }
 }
 
+/// Get the minimal compute time for each op in the function.
+void Partitioner::initOpComputeTime() {
+  computeTime_.clear();
+
+  // This code assumes all ops are BW limited from SRAM; except
+  // if the input does not fit in SRAM -- then it is DRAM BW limited.
+  float peakDramBw = deviceInfo_[0].peakDramBw;
+  float peakSramBw = deviceInfo_[0].peakSramBw;
+  size_t sramCapacity = deviceInfo_[0].sramCapacity;
+  float peakCompute = deviceInfo_[0].peakCompute;
+
+  for (auto &node : F_->getNodes()) {
+    /// compute memory side bytes for inputs from DRAM, SRAM.
+    /// TODO: think about whether this is better off computed inside a Node.
+
+    int n = node.getNumInputs();
+    uint64_t sizeDram = 0;
+    uint64_t sizeSram = 0;
+    if (node.getKind() == Kinded::Kind::SaveNodeKind) {
+      computeTime_[&node] = 0.0f;
+      continue;
+    }
+
+    /// The memory bytes for embedding table lookups are data dependent,
+    /// so they need to be calculated as per the number of indices accessed.
+    if (node.getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind) {
+      auto *SLWSN = llvm::dyn_cast<SparseLengthsWeightedSumNode>(&node);
+      /// compute how many entries of the embedding table we look up
+      auto numLookups = SLWSN->getIndices().getNode()->dims(0).front();
+      /// compute how many bytes we read per lookup
+      auto tableSize = SLWSN->getData().getNode()->getType(0)->getSizeInBytes();
+      auto numRows = SLWSN->getData().getNode()->dims(0).front();
+      auto sizePerLookup = tableSize / numRows;
+      /// compute total bytes read
+      uint64_t sizeInput = numLookups * sizePerLookup;
+
+      /// does the table fit in SRAM or DRAM
+      if (tableSize > sramCapacity) {
+        sizeDram += sizeInput;
+      } else {
+        sizeSram += sizeInput;
+      }
+
+      /// we also read the indices, weights and lengths arrays
+      sizeSram += SLWSN->getIndices().getNode()->getType(0)->getSizeInBytes();
+      sizeSram += SLWSN->getWeights().getNode()->getType(0)->getSizeInBytes();
+      sizeSram += SLWSN->getLengths().getNode()->getType(0)->getSizeInBytes();
+    } else {
+      /// for all other ops, iterate through all inputs and get size in bytes
+      for (int i = 0; i < n; i++) {
+        auto ty = node.getNthInput(i).getNode()->getType(0);
+        uint64_t sizeInput = ty->getSizeInBytes();
+        if (sizeInput > sramCapacity) {
+          sizeDram += sizeInput;
+        } else {
+          sizeSram += sizeInput;
+        }
+      }
+    }
+
+    // Repeat for outputs.
+    if (node.getNumResults() > 0) {
+      auto myty = node.getType(0);
+      uint64_t sizeOutput = myty->getSizeInBytes();
+      if (sizeOutput > sramCapacity) {
+        sizeDram += sizeOutput;
+      } else {
+        sizeSram += sizeOutput;
+      }
+    }
+
+    /// Calculate compute ops. Currently only computed for MatMul, Conv, FC.
+    /// TODO: think about whether this is better off computed inside a Node.
+    uint64_t totalOps = 0;
+    switch (node.getKind()) {
+    case Kinded::Kind::MatMulNodeKind: {
+      auto *MMN = llvm::dyn_cast<MatMulNode>(&node);
+      auto lhsDims = MMN->getLHS().dims();
+      auto rhsDims = MMN->getRHS().dims();
+      totalOps = 2 * lhsDims[0] * lhsDims[1] * rhsDims[1];
+      break;
+    }
+    case Kinded::Kind::FullyConnectedNodeKind: {
+      auto *FCN = llvm::dyn_cast<FullyConnectedNode>(&node);
+      auto inputDims = FCN->getInput().dims();
+      auto wtDims = FCN->getWeights().dims();
+      totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[1];
+      break;
+    }
+    case Kinded::Kind::ConvolutionNodeKind: {
+      auto *CN = llvm::dyn_cast<ConvolutionNode>(&node);
+      auto resultDims = CN->getResult().dims();
+      // Get the product of batch, output height, output width, output channels.
+      totalOps = resultDims[0];
+      for (size_t i = 1, e = resultDims.size(); i < e; i++) {
+        totalOps *= resultDims[i];
+      }
+      // Multiply in kernel height, kernel width.
+      auto kernelDims = CN->getKernels();
+      totalOps *= kernelDims[0] * kernelDims[1];
+      // Multiply in input channels/groups.
+      auto inputChannels = CN->getInput().dims()[1];
+      auto nGroups = CN->getGroup();
+      totalOps *= (inputChannels * 1.0 / nGroups);
+      break;
+    }
+    default:
+      break;
+    }
+
+    /// Compute the roofline as the max of the flops, DRAM BW and SRAM BW terms.
+    /// See https://bit.ly/2UdJ3mz
+    /// Add epsilons to prevent seg faults on uninitialized peak values.
+    computeTime_[&node] =
+        std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f),
+                 std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f),
+                          sizeSram * 1.0f / std::max(peakSramBw, 1e-6f)));
+  }
+}
+
 // Combine the partitions if necessary : if all outside uses of the nodes in
 // partition1 is in partition2, and the sum of memory consumption of partition1
 // and partition2 is less than availableMemory, combine partition1 and
@@ -403,7 +523,6 @@ DAGNodeList &Partitioner::Partition() {
 
   // Find the representive function for running partitioning algrithm.
   F_ = selectRepFunc(module_, memSize_);
-
   size_t availMem = deviceInfo_[0].availableMemory;
 
   if (memSize_ < availMem) {
@@ -427,9 +546,8 @@ DAGNodeList &Partitioner::Partition() {
   // Prepare 1: Get the min memory usage for each op.
   initOpMemUsage();
 
-  // Prepare 2: TODO: get the minimal comunication cost for any 2 ops (i.e. the
-  // output data size) Will calculate it on the fly. -- Will double check which
-  // way is better.
+  // Prepare 2: Get the roofline memory bandwidth estimate for each op.
+  initOpComputeTime();
 
   // Partition
   // Use BFS to do the initial partitioning. Starting from the final node, BFS
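To make the roofline concrete: the estimate for a node is max(totalOps / peakCompute, sizeDram / peakDramBw, sizeSram / peakSramBw). The sketch below recomputes it standalone for the matmul of the unit test's initial FullyConnected layer ({1, 32} x {32, 16} floats), with the same device numbers the test uses; rooflineTime and main are illustrative only, not Glow code.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone re-implementation of the roofline estimate above, for
// illustration only (rooflineTime is not a Glow API).
static float rooflineTime(uint64_t totalOps, uint64_t sizeDram,
                          uint64_t sizeSram, float peakCompute,
                          float peakDramBw, float peakSramBw) {
  return std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f),
                  std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f),
                           sizeSram * 1.0f / std::max(peakSramBw, 1e-6f)));
}

int main() {
  // MatMul of the test's initial FC: LHS {1, 32} x RHS {32, 16}, floats.
  uint64_t totalOps = 2 * 1 * 32 * 16; // 1024 ops
  // Inputs: 1x32 floats (128 B) and 32x16 floats (2048 B); both exceed the
  // 100 B SRAM capacity, so they count against DRAM. The 1x16 output (64 B)
  // fits in SRAM.
  uint64_t sizeDram = 128 + 2048; // 2176 B
  uint64_t sizeSram = 64;         // 64 B
  // Device numbers from PartitionerTest.cpp: peakCompute = 10,
  // peakDramBw = 0.1, peakSramBw = 1.
  printf("%f\n", rooflineTime(totalOps, sizeDram, sizeSram, 10, 0.1f, 1));
  // Prints 21760.000000: the DRAM term 2176 / 0.1 dominates, matching the
  // expected "fc_dot" compute time in the test below.
  return 0;
}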

tests/unittests/PartitionerTest.cpp

Lines changed: 101 additions & 0 deletions
@@ -208,3 +208,104 @@ TEST_F(PartitionerTest, Basic2) {
     EXPECT_TRUE(ref.isEqual(test));
   }
 }
+
+/// This one tests the roofline computed with compute, memory and communication
+/// costs
+TEST_F(PartitionerTest, Basic1Roofline) {
+  auto *input =
+      mod_.createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
+  auto *w1 = mod_.createConstant(ElemKind::FloatTy, {32, 16}, "w1");
+  auto *b1 = mod_.createConstant(ElemKind::FloatTy, {16}, "b1");
+  ctx_.allocate(input);
+  w1->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  b1->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+
+  // Initial FC.
+  Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
+  I = F_->createSigmoid("initial_sigmoid", I);
+
+  // Left branch.
+  auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2");
+  auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2");
+  w2->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  b2->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
+  L = F_->createSigmoid("left_sigmoid1", L);
+  auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3");
+  auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3");
+  w3->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  b3->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  L = F_->createFullyConnected("left_fc2", L, w3, b3);
+  L = F_->createSigmoid("left_sigmoid2", L);
+
+  // Right branch.
+  auto *w4 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w4");
+  auto *b4 = mod_.createConstant(ElemKind::FloatTy, {16}, "b4");
+  w4->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  b4->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
+  R = F_->createSigmoid("right_sigmoid1", R);
+  auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5");
+  auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5");
+  w5->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  b5->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
+  R = F_->createFullyConnected("right_fc2", R, w5, b5);
+  R = F_->createSigmoid("right_sigmoid2", R);
+
+  // Join branches.
+  auto *mul = F_->createMul("mul", L, R);
+  auto *save = F_->createSave("ret", mul);
+  auto &res = *ctx_.allocate(save->getPlaceholder());
+
+  // Infer using the un-partitioned graph.
+  Tensor in(ElemKind::FloatTy, {1, 32});
+  ExecutionEngine EE;
+
+  EE.compile(CompilationMode::Infer, F_);
+  updateInputPlaceholders(ctx_, {input}, {&in});
+  EE.run(ctx_);
+  Tensor ref = res.clone();
+
+  std::unordered_map<Node *, std::string> nodeNamesMap;
+  for (auto &node : F_->getNodes()) {
+    nodeNamesMap[&node] = node.getName();
+  }
+
+  std::vector<DeviceInfo> devices = {{3072, 100, 10, 0.1, 1, 0.05},
+                                     {3072, 100, 10, 0.1, 1, 0.05},
+                                     {3072, 100, 10, 0.1, 1, 0.05}};
+  Partitioner myPartitioner(&mod_, devices);
+
+  DAGNodeList myList = std::move(myPartitioner.Partition());
+
+  // Check compute costs.
+  std::unordered_map<std::string, float> expectedComputeTime{
+      {"initial_sigmoid", 128},
+      {"left_sigmoid2", 64},
+      {"fc_add_bias3", 192},
+      {"right_sigmoid1", 128},
+      {"mul", 96},
+      {"fc_add_bias2", 96},
+      {"ret", 0},
+      {"fc_dot", 21760},
+      {"left_sigmoid1", 128},
+      {"fc_add_bias", 192},
+      {"fc_dot1", 10240},
+      {"right_sigmoid2", 64},
+      {"fc_add_bias1", 192},
+      {"fc_dot2", 5120},
+      {"fc_dot3", 10240},
+      {"fc_dot4", 5120},
+      {"fc_add_bias4", 96},
+  };
+  ASSERT_EQ(myPartitioner.getComputeTime().size(), expectedComputeTime.size());
+  for (auto &el : myPartitioner.getComputeTime()) {
+    Node *n = el.first;
+    float expected = expectedComputeTime[nodeNamesMap[n].c_str()];
+    float res = el.second;
+    ASSERT_EQ(expected, res);
+  }
+
+  ASSERT_EQ(mod_.getFunctions().size(), 3);
+  ASSERT_EQ(myList.roots.size(), 1);
+}
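For reference, each expected value above follows directly from the roofline formula and the toy device parameters (sramCapacity = 100 B, peakCompute = 10 ops/s, peakDramBw = 0.1 B/s, peakSramBw = 1 B/s), assuming each FullyConnected node is lowered into an fc_dot matmul plus an fc_add_bias node, as the expected names suggest. For example, "initial_sigmoid" reads and writes one 1x16 float tensor each (64 B in, 64 B out, both fitting in SRAM), giving 128 / 1 = 128; "fc_dot" for the initial FC reads a 1x32 input (128 B) and a 32x16 weight (2048 B), both larger than SRAM, so the DRAM term 2176 / 0.1 = 21760 dominates its 1024-op compute term; and the SaveNode "ret" is assigned a cost of 0.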
