[Partitioner] Add cost functions to partitioner #2441
Changes from all commits
@@ -81,6 +81,126 @@ void Partitioner::initOpMemUsage() {
  }
}

/// Get the minimal compute time for each op in the function.
void Partitioner::initOpComputeTime() {
  computeTime_.clear();

  // This code assumes all ops are BW limited from SRAM; except
  // if the input does not fit in SRAM -- then it is DRAM BW limited.
  float peakDramBw = deviceInfo_[0].peakDramBw;
  float peakSramBw = deviceInfo_[0].peakSramBw;
  size_t sramCapacity = deviceInfo_[0].sramCapacity;
  float peakCompute = deviceInfo_[0].peakCompute;

  for (auto &node : F_->getNodes()) {
    /// compute memory side bytes for inputs from DRAM, SRAM.
    /// TODO: think about whether this is better off computed inside a Node.

    int n = node.getNumInputs();
    uint64_t sizeDram = 0;
    uint64_t sizeSram = 0;
    if (node.getKind() == Kinded::Kind::SaveNodeKind) {
      computeTime_[&node] = 0.0f;
      continue;
    }

    /// The memory bytes for embedding table lookups are data dependent,
    /// so they need to be calculated as per the number of indices accessed.
    if (node.getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind) {
      auto *SLWSN = llvm::dyn_cast<SparseLengthsWeightedSumNode>(&node);
      /// compute how many entries of the embedding table we look up
      auto numLookups = SLWSN->getIndices().getNode()->dims(0).front();
      /// compute how many bytes we read per lookup
      auto tableSize = SLWSN->getData().getNode()->getType(0)->getSizeInBytes();
      auto numRows = SLWSN->getData().getNode()->dims(0).front();
      auto sizePerLookup = tableSize / numRows;
      /// compute total bytes read
      uint64_t sizeInput = numLookups * sizePerLookup;

      /// does the table fit in SRAM or DRAM
      if (tableSize > sramCapacity) {
        sizeDram += sizeInput;
      } else {
        sizeSram += sizeInput;
      }

      /// we also read the indices, weights and lengths arrays
      sizeSram += SLWSN->getIndices().getNode()->getType(0)->getSizeInBytes();
      sizeSram += SLWSN->getWeights().getNode()->getType(0)->getSizeInBytes();
      sizeSram += SLWSN->getLengths().getNode()->getType(0)->getSizeInBytes();
    } else {
      /// for all other ops, iterate through all inputs and get size in bytes
      for (int i = 0; i < n; i++) {
        auto ty = node.getNthInput(i).getNode()->getType(0);
        uint64_t sizeInput = ty->getSizeInBytes();
        if (sizeInput > sramCapacity) {
          sizeDram += sizeInput;
        } else {
          sizeSram += sizeInput;
        }
      }
    }

    // Repeat for outputs.
    if (node.getNumResults() > 0) {
      auto myty = node.getType(0);
      uint64_t sizeOutput = myty->getSizeInBytes();
      if (sizeOutput > sramCapacity) {
        sizeDram += sizeOutput;
      } else {
        sizeSram += sizeOutput;
      }
    }

    /// Calculate compute ops. Currently only computed for MatMul, Conv, FC.
    /// TODO: think about whether this is better off computed inside a Node.
    uint64_t totalOps = 0;
Review comment: Just want to double check again here: in the future, do we need to add the computation for each node?

Reply: Yes we do. At least for memory bytes, if not flops.

Reply: But for most ops, flops are less important. There are only a handful of ops here that will be at all compute bound.
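As an aside on the exchange above: a rough back-of-the-envelope comparison shows why, under a roofline model, only a few op types end up compute bound. The snippet below is purely illustrative; the device numbers (`peakCompute`, `peakSramBw`) and tensor shapes are assumptions, not values from this PR or from any DeviceInfo.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical device characteristics (illustrative only).
  const float peakCompute = 100e12f; // 100 TFLOP/s
  const float peakSramBw = 10e12f;   // 10 TB/s

  // Elementwise add over N floats: N FLOPs, roughly 3*N*4 bytes moved.
  const uint64_t n = 1 << 20;
  float addCompute = n / peakCompute;
  float addMemory = (3.0f * n * 4) / peakSramBw;

  // 1024x1024x1024 matmul: 2*M*K*N FLOPs, roughly 3*M*N*4 bytes moved.
  const uint64_t m = 1024;
  float mmCompute = (2.0f * m * m * m) / peakCompute;
  float mmMemory = (3.0f * m * m * 4) / peakSramBw;

  // The roofline time is the max of the compute and memory estimates.
  std::printf("add:    compute %.3g s, memory %.3g s -> %s bound\n",
              addCompute, addMemory,
              addCompute > addMemory ? "compute" : "memory");
  std::printf("matmul: compute %.3g s, memory %.3g s -> %s bound\n",
              mmCompute, mmMemory,
              mmCompute > mmMemory ? "compute" : "memory");
  return 0;
}
```

With these made-up numbers the elementwise add is memory bound by roughly two orders of magnitude, while the large matmul is compute bound, which matches the point that only a handful of ops need a FLOP count at all.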
    switch (node.getKind()) {
    case Kinded::Kind::MatMulNodeKind: {
      auto *MMN = llvm::dyn_cast<MatMulNode>(&node);
Review comment: I prefer using "switch". If we need to add more node types here, "switch" looks better :)

Reply: Makes sense.
      auto lhsDims = MMN->getLHS().dims();
      auto rhsDims = MMN->getRHS().dims();
      totalOps = 2 * lhsDims[0] * lhsDims[1] * rhsDims[1];
      break;
    }
    case Kinded::Kind::FullyConnectedNodeKind: {
      auto *FCN = llvm::dyn_cast<FullyConnectedNode>(&node);
      auto inputDims = FCN->getInput().dims();
      auto wtDims = FCN->getWeights().dims();
      totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[1];
      break;
    }
    case Kinded::Kind::ConvolutionNodeKind: {
      auto *CN = llvm::dyn_cast<ConvolutionNode>(&node);
      auto resultDims = CN->getResult().dims();
      // Get the product of batch, output height, output width, output channels
      totalOps = resultDims[0];
      for (size_t i = 1, e = resultDims.size(); i < e; i++) {
        totalOps *= resultDims[i];
Review comment: Here "i" should be size_t, otherwise the type check will fail.
      }
      // Multiply in kernel height, kernel width
      auto kernelDims = CN->getKernels();
      totalOps *= kernelDims[0] * kernelDims[1];
      // Multiply in input channels/groups
      auto inputChannels = CN->getInput().dims()[1];
      auto nGroups = CN->getGroup();
      totalOps *= (inputChannels * 1.0 / nGroups);
      break;
    }
    default:
      break;
    }

    /// Compute the compute roofline as the max of flops, DRAM, SRAM BW.
    /// See https://bit.ly/2UdJ3mz
    /// Add epsilons to prevent seg faults on uninitialized peak values.
    computeTime_[&node] =
        std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f),
                 std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f),
                          sizeSram * 1.0f / std::max(peakSramBw, 1e-6f)));
  }
}
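To make the expression at the end of initOpComputeTime concrete, here is a small standalone sketch that evaluates the same max-of-three-ratios roofline estimate for a hypothetical fully connected layer. The helper name `rooflineTime`, the device numbers, and the layer shape are illustrative assumptions; only the formula itself mirrors the PR.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Roofline-style estimate mirroring the expression in initOpComputeTime:
// time = max(flops / peakCompute, dramBytes / peakDramBw, sramBytes / peakSramBw).
float rooflineTime(uint64_t totalOps, uint64_t sizeDram, uint64_t sizeSram,
                   float peakCompute, float peakDramBw, float peakSramBw) {
  // Epsilons guard against uninitialized / zero peak values.
  return std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f),
                  std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f),
                           sizeSram * 1.0f / std::max(peakSramBw, 1e-6f)));
}

int main() {
  // Hypothetical FC layer: batch 64, 1024 -> 4096, 2-byte elements.
  const uint64_t flops = 2ull * 64 * 1024 * 4096;            // 2 * M * K * N
  const uint64_t weightBytes = 1024ull * 4096 * 2;           // weight matrix
  const uint64_t actBytes = (64ull * 1024 + 64ull * 4096) * 2; // input + output

  // Assume the weights fit in SRAM on this hypothetical device.
  float t = rooflineTime(flops, /*sizeDram=*/0, weightBytes + actBytes,
                         /*peakCompute=*/50e12f, /*peakDramBw=*/100e9f,
                         /*peakSramBw=*/5e12f);
  std::printf("estimated time: %g s\n", t);
  return 0;
}
```

With these numbers the FLOP term dominates the SRAM term, so the per-op compute time comes out compute bound; shrinking the batch or spilling the table to DRAM flips it to a bandwidth-bound estimate.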

// Combine the partitions if necessary: if all outside uses of the nodes in
// partition1 are in partition2, and the sum of memory consumption of partition1
// and partition2 is less than availableMemory, combine partition1 and
@@ -403,7 +523,6 @@ DAGNodeList &Partitioner::Partition() {

  // Find the representative function for running the partitioning algorithm.
  F_ = selectRepFunc(module_, memSize_);

  size_t availMem = deviceInfo_[0].availableMemory;

  if (memSize_ < availMem) {
@@ -427,9 +546,8 @@ DAGNodeList &Partitioner::Partition() {
  // Prepare 1: Get the min memory usage for each op.
  initOpMemUsage();

  // Prepare 2: TODO: get the minimal communication cost for any 2 ops (i.e. the
  // output data size) Will calculate it on the fly. -- Will double check which
  // way is better.
  // Prepare 2: Get the roofline memory bandwidth estimate for each op.
  initOpComputeTime();

  // Partition
  // Use BFS to do the initial partitioning. Starting from the final node, BFS
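For context on how per-op estimates like these might be consumed downstream, here is a deliberately simplified, hypothetical sketch. It is not Glow's actual BFS-based partitioning algorithm; `OpCost` and `naivePartition` are made-up names, and the only grounded idea is that each op carries a memory-usage and compute-time cost that a partitioner checks against a device memory budget.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical per-op cost record; in the PR these values live in the
// memUsage_ and computeTime_ maps keyed by Node*.
struct OpCost {
  uint64_t memBytes;
  float computeTime;
};

// Greedy illustration only: pack ops in order until availableMemory is hit,
// then start a new partition.
std::vector<std::vector<size_t>> naivePartition(const std::vector<OpCost> &ops,
                                                uint64_t availableMemory) {
  std::vector<std::vector<size_t>> partitions(1);
  uint64_t used = 0;
  for (size_t i = 0; i < ops.size(); i++) {
    if (used + ops[i].memBytes > availableMemory && !partitions.back().empty()) {
      partitions.emplace_back();
      used = 0;
    }
    partitions.back().push_back(i);
    used += ops[i].memBytes;
  }
  return partitions;
}

int main() {
  std::vector<OpCost> ops = {{4 << 20, 1e-5f}, {8 << 20, 2e-5f}, {6 << 20, 5e-6f}};
  auto parts = naivePartition(ops, /*availableMemory=*/10 << 20);
  std::printf("partitions: %zu\n", parts.size());
  return 0;
}
```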
Review comment: Out of curiosity: we assume here DRAM and SRAM as the only two layers of the memory hierarchy, and that is fine for now. But do we need to support more levels/kinds of memory hierarchy in a general case (e.g. different caches, etc.)?

Reply: Ideally yes. In general, it would be good to get the characteristics of the architecture from querying some API. This current design is the first step towards that.
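Following up on the thread above, one possible direction (purely hypothetical, not part of this PR or of Glow's DeviceInfo) would be to describe the device as an ordered list of memory levels rather than a fixed DRAM/SRAM pair, so the cost model can walk the hierarchy instead of hard-coding two tiers. The names `MemoryLevel`, `DeviceCharacteristics`, and `placeTensor` below are invented for illustration.

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical generalization of the fields used in this PR
// (peakDramBw, peakSramBw, sramCapacity, peakCompute): each memory level
// carries its own capacity and bandwidth.
struct MemoryLevel {
  std::string name;  // e.g. "SRAM", "HBM", "DRAM"
  uint64_t capacity; // bytes
  float peakBw;      // bytes per second
};

struct DeviceCharacteristics {
  float peakCompute;               // FLOP/s
  std::vector<MemoryLevel> levels; // ordered fastest/smallest first
};

// Place a tensor in the first (fastest) level it fits in; otherwise fall
// back to the last (largest) level.
const MemoryLevel &placeTensor(const DeviceCharacteristics &dev, uint64_t bytes) {
  for (const auto &level : dev.levels) {
    if (bytes <= level.capacity) {
      return level;
    }
  }
  return dev.levels.back();
}

int main() {
  DeviceCharacteristics dev{100e12f,
                            {{"SRAM", 64ull << 20, 10e12f},
                             {"DRAM", 16ull << 30, 100e9f}}};
  const auto &lvl = placeTensor(dev, 128ull << 20); // 128 MiB tensor
  std::printf("placed in %s\n", lvl.name.c_str());
  return 0;
}
```

A structure like this could be filled in by querying the backend for its characteristics, which is the direction the reply above suggests.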