[Partitioner] Partitioner Refactoring 2 #3318

Closed
wants to merge 1 commit into from
31 changes: 1 addition & 30 deletions include/glow/Partitioner/Partitioner.h
@@ -16,7 +16,7 @@
#ifndef GLOW_PARTITIONER_PARTITIONER_H
#define GLOW_PARTITIONER_PARTITIONER_H

#include "glow/Partitioner/PartitionerUtils.h"
#include "glow/Partitioner/PartitionerTypes.h"
#include "glow/Support/Error.h"

namespace glow {
@@ -26,9 +26,6 @@ using namespace runtime;
/// Given a module, partitions each of its functions into multiple ones
/// based on memory constraints and minimizes the communication cost.
class Partitioner {
using MemUsageMap = std::unordered_map<Node *, uint64_t>;
using ComputeTimeMap = std::unordered_map<Node *, float>;

/// The module that needs to be decomposed.
Module *module_;

@@ -59,12 +56,6 @@ class Partitioner {
/// Total memory (bytes) requested by one module.
uint64_t memSize_;

/// The map of each operator and the corresponding memory size.
MemUsageMap memUsage_;

/// The map of each operator and the compute runtime.
ComputeTimeMap computeTime_;

/// Flag to set if the Partitioner should attempt to saturate the host, and
/// use all available devices.
bool saturateHost_;
@@ -81,12 +72,6 @@
/// update the memSize.
static Function *selectRepFunc(Module *parent, uint64_t &memSize);

/// Get the minimal memory requirement for each op in the function \p F
void initOpMemUsage(Function *F);

/// Initialize the minimal compute time for each op in the function \p F.
void initOpComputeTime(Function *F);

/// After getting the initial partitions, adjust the partitions to minimize
/// communication and computation cost.
void partitionsAdjust(NodeToFunctionMap &partitions,
@@ -180,20 +165,6 @@
/// a function family and they have the same partition, we only dump the one
/// function's partition.
void dumpDAG(llvm::StringRef dotFilename) const;

/// Get function for computeTime_
float getComputeTime(Node *N) const {
auto it = computeTime_.find(N);
assert(it != computeTime_.end());
return it == computeTime_.end() ? 0.0 : it->second;
}

/// Get function for memUsage_
uint64_t getMemUsage(Node *N) const {
auto it = memUsage_.find(N);
assert(it != memUsage_.end());
return it == memUsage_.end() ? 0 : it->second;
}
};
} // namespace glow
#endif // GLOW_PARTITIONER_PARTITIONER_H
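
Net effect of this header change: the Partitioner no longer owns cached per-node estimates (memUsage_, computeTime_) or their accessors; callers compute the same estimates on demand through the PartitionerUtils helpers declared below. A minimal caller-side sketch of the change, mirroring the loadBalancedPartitioning diff further down (`N`, `backendMap_`, and `deviceInfo_` as in that code):

```cpp
// Before this PR: values read from maps precomputed by
// initOpMemUsage(F) / initOpComputeTime(F).
//   float t    = getComputeTime(N);
//   uint64_t m = getMemUsage(N);

// After this PR: computed on demand, parameterized by the target backend's info.
const BackendInfo &bi = backendMap_[deviceInfo_[0].backendName];
float t    = getNodeComputeTime(N, bi);
uint64_t m = getNodeMemUsage(N);
```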
11 changes: 11 additions & 0 deletions include/glow/Partitioner/PartitionerTypes.h
@@ -65,6 +65,17 @@ struct BackendInfo {
size_t num = 0;
/// The memory constraints for this backend.
uint64_t memSize;
/// The following peakCompute, peakDramBw, peakSramBw, peakPCIeBw are from
/// DeviceInfo_. Available SRAM capacity in bytes.
uint64_t sramCapacity;
/// Peak compute on device in ops/second. Assumes all ops are in int8.
float peakCompute;
/// Peak memory bandwidth from DRAM on device in bytes/second.
float peakDramBw;
/// Peak memory bandwidth from SRAM on device in bytes/second.
float peakSramBw;
/// Peak ingress/egress PCI-E bandwidth from device in bytes/second.
float peakPCIeBw;
/// Backend pointer.
Backend *backend = nullptr;
/// The non-supported nodes kind.
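
These new fields mirror the roofline parameters already carried per device by DeviceInfo (the getBackendMap change in lib/Partitioner/Partitioner.cpp below simply copies them over), so the per-node compute-time estimate can be evaluated from a BackendInfo alone. The estimate itself, unchanged from the removed Partitioner::initOpComputeTime, is a roofline maximum:

$$ t_{\text{node}} = \max\left(\frac{\text{ops}}{\text{peakCompute}},\ \frac{\text{bytes}_{\text{DRAM}}}{\text{peakDramBw}},\ \frac{\text{bytes}_{\text{SRAM}}}{\text{peakSramBw}}\right) $$

where ops is the node's estimated int8 op count, and each input/output tensor contributes its bytes to the DRAM or SRAM term depending on whether it fits within sramCapacity. peakPCIeBw did not appear in the removed estimator.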
8 changes: 7 additions & 1 deletion include/glow/Partitioner/PartitionerUtils.h
@@ -36,9 +36,15 @@ std::vector<Node *> getOutUsersWithOnePredecessor(const NodesSet &nodes);
/// in the set \p nodes.
uint64_t getOutMemPerNode(const NodesSet &nodes, const Node *node);

/// Given a node, \return the NodeSet of inputs of this node.
/// Given a node, \returns the NodeSet of inputs of this node.
NodesSet getInputs(const Node *node);

/// Return the estimated op computation time based on \p backendInfo.
float getNodeComputeTime(const Node *node, const BackendInfo &backendInfo);

/// Given a node, \returns the memory usage of its inputs (i.e. Storage input).
uint64_t getNodeMemUsage(const Node *node);

/// Given nodes set \p currNodes and its memory usage info \p info, \returns the
/// new memory usage if \p newNode is added into \p currNodes.
GraphMemInfo updateGraphMemInfoByAddingNode(const NodesSet &currNodes,
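
The definitions of the two helpers added here live in lib/Partitioner/PartitionerUtils.cpp, which is not part of this diff. Going by the initOpMemUsage body removed from Partitioner.cpp below, getNodeMemUsage presumably reduces to a sketch like the following (an approximation for orientation, not a quote of the real definition):

```cpp
// Sketch only -- the real definition is in PartitionerUtils.cpp (not shown in this PR).
uint64_t getNodeMemUsage(const Node *node) {
  // Save nodes get zero memory usage, as in the removed initOpMemUsage().
  if (node->getKind() == Kinded::Kind::SaveNodeKind) {
    return 0;
  }
  uint64_t size = 0;
  // Sum the sizes of all Storage (Constant/Placeholder) inputs of the node.
  for (unsigned i = 0, e = node->getNumInputs(); i < e; i++) {
    if (auto *in = llvm::dyn_cast<Storage>(node->getNthInput(i).getNode())) {
      size += in->getType()->getSizeInBytes();
    }
  }
  return size;
}
```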
241 changes: 11 additions & 230 deletions lib/Partitioner/Partitioner.cpp
@@ -17,6 +17,7 @@
#include "glow/Partitioner/Partitioner.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Partitioner/PartitionerOptimizer.h"
#include "glow/Partitioner/PartitionerUtils.h"
#include "glow/Partitioner/PartitionerValidation.h"
#include "glow/Support/Support.h"

@@ -174,227 +175,6 @@ Function *Partitioner::selectRepFunc(Module *parent, uint64_t &memSize) {
return ret;
}

/// Get the minimal memory requirement (constant) for each op in the function.
void Partitioner::initOpMemUsage(Function *F) {
memUsage_.clear();
for (auto &node : F->getNodes()) {
int n = node.getNumInputs();
uint64_t size = 0;
if (node.getKind() == Kinded::Kind::SaveNodeKind) {
memUsage_[&node] = size;
continue;
}
for (int i = 0; i < n; i++) {
Storage *in = llvm::dyn_cast<Storage>(node.getNthInput(i).getNode());
if (in) {
auto ty = in->getType();
size += ty->getSizeInBytes();
}
}
memUsage_[&node] = size;
}
}

/// Get the minimal compute time for each op in the function.
void Partitioner::initOpComputeTime(Function *F) {
computeTime_.clear();

// This code assumes all ops are BW limited from SRAM; except
// if the input does not fit in SRAM -- then it is DRAM BW limited
float peakDramBw = deviceInfo_[0].peakDramBw;
float peakSramBw = deviceInfo_[0].peakSramBw;
uint64_t sramCapacity = deviceInfo_[0].sramCapacity;
float peakCompute = deviceInfo_[0].peakCompute;

for (auto &node : F->getNodes()) {
// compute memory side bytes for inputs from DRAM, SRAM.
// TODO: think about whether this is better off computed inside a Node.

int n = node.getNumInputs();
uint64_t sizeDram = 0;
uint64_t sizeSram = 0;
if (node.getKind() == Kinded::Kind::SaveNodeKind) {
computeTime_[&node] = 0.0f;
continue;
}

// The memory bytes for embedding table lookups is data dependent,
// so it needs to be calculated as per the number of indices accessed.
if (node.getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind) {
auto *SLWSN = llvm::dyn_cast<SparseLengthsWeightedSumNode>(&node);
// compute how many entries of the embedding table we look up
auto numLookups = SLWSN->getIndices().dims().front();
// compute how many bytes we read per lookup
auto tableSize = SLWSN->getData().getType()->getSizeInBytes();
auto numRows = SLWSN->getData().dims().front();
auto sizePerLookup = tableSize / numRows;
// compute total bytes read
uint64_t sizeInput = numLookups * sizePerLookup;

// tables are usually large and fit in DRAM
sizeDram += sizeInput;
// we also read the indices, weights and lengths arrays
sizeSram += SLWSN->getIndices().getType()->getSizeInBytes();
sizeSram += SLWSN->getWeights().getType()->getSizeInBytes();
sizeSram += SLWSN->getLengths().getType()->getSizeInBytes();
} else if (node.getKind() == Kinded::Kind::SparseLengthsSumNodeKind) {
auto *SLSN = llvm::dyn_cast<SparseLengthsSumNode>(&node);
// compute how many entries of the embedding table we look up
auto numLookups = SLSN->getIndices().dims().front();
// compute how many bytes we read per lookup
auto tableSize = SLSN->getData().getType()->getSizeInBytes();
auto numRows = SLSN->getData().dims().front();
auto sizePerLookup = tableSize / numRows;
// compute total bytes read
uint64_t sizeInput = numLookups * sizePerLookup;

// tables are usually large and fit in DRAM
sizeDram += sizeInput;
// we also read the indices and lengths arrays
sizeSram += SLSN->getIndices().getType()->getSizeInBytes();
sizeSram += SLSN->getLengths().getType()->getSizeInBytes();
} else if (node.getKind() ==
Kinded::Kind::
FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind) {
auto *FRQSLWSN =
llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
&node);
// compute how many entries of the embedding table we look up
auto numLookups = FRQSLWSN->getIndices().dims().front();
// compute how many bytes we read per lookup
auto tableSize = FRQSLWSN->getData().getType()->getSizeInBytes();
auto numRows = FRQSLWSN->getData().dims().front();
auto sizePerLookup = tableSize / numRows;
// compute total bytes read
uint64_t sizeInput = numLookups * sizePerLookup;

// tables are usually large and fit in DRAM
sizeDram += sizeInput;

// we also read the indices, weights and lengths arrays
sizeSram += FRQSLWSN->getIndices().getType()->getSizeInBytes();
sizeSram += FRQSLWSN->getWeights().getType()->getSizeInBytes();
sizeSram += FRQSLWSN->getLengths().getType()->getSizeInBytes();
} else if (node.getKind() ==
Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind) {
auto *FRQSLSN =
llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>(&node);
// compute how many entries of the embedding table we look up
auto numLookups = FRQSLSN->getIndices().dims().front();
// compute how many bytes we read per lookup
auto tableSize = FRQSLSN->getData().getType()->getSizeInBytes();
auto numRows = FRQSLSN->getData().dims().front();
auto sizePerLookup = tableSize / numRows;
// compute total bytes read
uint64_t sizeInput = numLookups * sizePerLookup;

// tables are usually large and fit in DRAM
sizeDram += sizeInput;

// we also read the indices and lengths arrays
sizeSram += FRQSLSN->getIndices().getType()->getSizeInBytes();
sizeSram += FRQSLSN->getLengths().getType()->getSizeInBytes();
} else {
// for all other ops, iterate through all inputs and get size in bytes
for (int i = 0; i < n; i++) {
auto ty = node.getNthInput(i).getType();
uint64_t sizeInput = ty->getSizeInBytes();
if (sizeInput > sramCapacity) {
sizeDram += sizeInput;
} else {
sizeSram += sizeInput;
}
}
}

// Repeat for outputs
for (size_t i = 0, e = node.getNumResults(); i < e; i++) {
auto myty = node.getType(i);
uint64_t sizeOutput = myty->getSizeInBytes();
if (sizeOutput > sramCapacity) {
sizeDram += sizeOutput;
} else {
sizeSram += sizeOutput;
}
}

// Calculate compute ops. Currently only computed for Matmul, Conv, FC
// TODO: think about whether this is better off computed inside a Node.
uint64_t totalOps = 0;
switch (node.getKind()) {
case Kinded::Kind::MatMulNodeKind: {
auto *MMN = llvm::dyn_cast<MatMulNode>(&node);
auto lhsDims = MMN->getLHS().dims();
auto rhsDims = MMN->getRHS().dims();
totalOps = 2 * lhsDims[0] * lhsDims[1] * rhsDims[1];
break;
}
case Kinded::Kind::FullyConnectedNodeKind: {
auto *FCN = llvm::dyn_cast<FullyConnectedNode>(&node);
auto inputDims = FCN->getInput().dims();
auto wtDims = FCN->getWeights().dims();
totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[0];
break;
}
#ifdef GLOW_WITH_HABANA
case Kinded::Kind::HabanaFullyConnectedNodeKind: {
auto *FCN = llvm::dyn_cast<HabanaFullyConnectedNode>(&node);
auto inputDims = FCN->getInput().dims();
auto wtDims = FCN->getWeights().dims();
totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[0];
break;
}
#endif
case Kinded::Kind::ConvolutionNodeKind: {
auto *CN = llvm::dyn_cast<ConvolutionNode>(&node);
auto resultDims = CN->getResult().dims();
// Get the product of batch, output height, output dims, output channels
totalOps = resultDims[0];
for (size_t i = 1, e = resultDims.size(); i < e; i++) {
totalOps *= resultDims[i];
}
// Multiply in kernel height, kernel width
auto kernelDims = CN->getKernels();
totalOps *= kernelDims[0] * kernelDims[1];
// Multiply in input channels/groups
auto inputChannels = CN->getInput().dims()[1];
auto nGroups = CN->getGroup();
totalOps *= (inputChannels * 1.0 / nGroups);
break;
}
#ifdef GLOW_WITH_HABANA
case Kinded::Kind::HabanaConvolutionNodeKind: {
auto *CN = llvm::dyn_cast<HabanaConvolutionNode>(&node);
auto resultDims = CN->getResult().dims();
// Get the product of batch, output height, output dims, output channels
totalOps = resultDims[0];
for (size_t i = 1, e = resultDims.size(); i < e; i++) {
totalOps *= resultDims[i];
}
// Multiply in kernel height, kernel width
auto kernelDims = CN->getKernels();
totalOps *= kernelDims[0] * kernelDims[1];
// Multiply in input channels/groups
auto inputChannels = CN->getInput().dims()[1];
auto nGroups = CN->getGroup();
totalOps *= (inputChannels * 1.0 / nGroups);
break;
}
#endif
default:
break;
}

// Compute compute roofline as max of flops, DRAM, SRAM BW
// See https://bit.ly/2UdJ3mz
// Add epsilons to prevent seg faults on uninitialized peak values.
computeTime_[&node] =
std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f),
std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f),
sizeSram * 1.0f / std::max(peakSramBw, 1e-6f)));
}
}
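
To make the roofline above concrete, here is one worked instance using the MatMul op count from this function and round, assumed device peaks (illustrative numbers only, not from any real DeviceInfo): a float32 MatMul with LHS [1024, 1024] and RHS [1024, 1024] has totalOps = 2 * 1024 * 1024 * 1024 ≈ 2.15e9. Each tensor is 4 MiB, so inputs plus output contribute about 1.26e7 bytes; assuming all three fit under sramCapacity, sizeSram ≈ 1.26e7 and sizeDram = 0. With an assumed peakCompute of 1e12 ops/s and peakSramBw of 1e11 bytes/s, the compute term is ≈ 2.15 ms and the SRAM term ≈ 0.13 ms, so the roofline max is ≈ 2.15 ms and this node is compute-bound.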

void Partitioner::partitionsAdjust(NodeToFunctionMap &partitions,
uint64_t availableMemory) {
// For each partition, create a node set.
@@ -743,6 +523,10 @@ void Partitioner::getBackendMap(
// is the same.
// TODO : will improve the algorithm for different memory size.
backendInfo.memSize = deviceInfo_[i].availableMemory;
backendInfo.peakDramBw = deviceInfo_[i].peakDramBw;
backendInfo.peakSramBw = deviceInfo_[i].peakSramBw;
backendInfo.sramCapacity = deviceInfo_[i].sramCapacity;
backendInfo.peakCompute = deviceInfo_[i].peakCompute;
backendInfo.nonSupportedNodesKinds =
generateNodeKindsSet(deviceInfo_[i].nonSupportedNodes);
backendInfo.supportedNodesKinds =
@@ -827,7 +611,8 @@ llvm::Error Partitioner::loadBalancedPartitioning(Function *F,
// Compute total roofline time
float totalRooflineTime = 0;
for (auto &n : F->getNodes()) {
totalRooflineTime += getComputeTime(&n);
totalRooflineTime +=
getNodeComputeTime(&n, backendMap_[deviceInfo_[0].backendName]);
}

float timePerPartition = totalRooflineTime / numDevices;
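
That is, the balancer targets an equal share of estimated roofline time per logical device. For example, if the per-node estimates over the function sum to 8 ms and numDevices is 4, then timePerPartition is 2 ms, and (as the rest of this function suggests) nodes are packed into a partition until its accumulated estimate approaches that budget before moving on to the next logical device. The 8 ms / 4-device figures are made up purely for illustration.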
@@ -869,8 +654,9 @@ llvm::Error Partitioner::loadBalancedPartitioning(Function *F,
}
}

auto curOpTime = getComputeTime(N);
auto curOpMemory = getMemUsage(N);
auto curOpTime =
getNodeComputeTime(N, backendMap_[deviceInfo_[0].backendName]);
auto curOpMemory = getNodeMemUsage(N);

// Find a partition to put this node into
int curPartition = maxLogicalDeviceId;
@@ -1001,12 +787,7 @@ llvm::Error Partitioner::Partition(CompilationContext &cctx) {
RETURN_IF_ERR(::glow::optimizeFunction(func, *backend, cctx));
}

// Step 2.2 : get the min memory usage and the roofline memory bandwidth
// estimate for each op.
initOpMemUsage(func);
initOpComputeTime(func);

// Step 2.3 : apply graph partitioning algrithm to find out the partition.
// Step 2.2 : apply graph partitioning algorithm to find out the partition.
NodeToFunctionMap partitionMap =
selectPartitions(func, availMem, i->second);
mapping.insert(partitionMap);