Skip to content

Commit 809a3e8

Browse files
Man Wangfacebook-github-bot
Man Wang
authored andcommitted
User-defined partition (#3237)
Summary: This PR added user-defined partition flow. Basically, a struct "PartitionConfig" containing the partition info is passed into Partitioner to enable this flow. Now we let users have the full control of how to do the partitioning. To use this flow, users can write their helper function to generate PartitionConfig, and call Partitioner directly. In the following PR, we will add passing PartitionConfig through HostManager from a yaml file. Related to #2298 Documentation: [Optional Fixes #issue] Pull Request resolved: #3237 Test Plan: Added unittest. ninja test. Please see a detailed explanation of how to fill out the fields in the relevant sections in PULL_REQUEST.md. Differential Revision: D16345755 Pulled By: beicy fbshipit-source-id: 64f601004880e1702f5f6eba19d4f5b7749f0864
1 parent c0d6d72 commit 809a3e8

File tree

4 files changed

+208
-52
lines changed

4 files changed

+208
-52
lines changed

include/glow/Partitioner/Partitioner.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,14 @@ class Partitioner {
204204
/// use all available devices.
205205
bool saturateHost_;
206206

207-
// Flag to set if the funcitons in the module are areadly optimized. By
208-
// default, the optimization should be done in Partitioner due to
209-
// heterogeneous partition.
207+
/// Flag to set if the funcitons in the module are areadly optimized. By
208+
/// default, the optimization should be done in Partitioner due to
209+
/// heterogeneous partition.
210210
bool optimized_;
211211

212+
/// The struct contain user-defined partition info.
213+
PartitionConfig partitionConfig_;
214+
212215
/// Get the representative function (the one with the largest input) and
213216
/// update the memSize.
214217
static Function *selectRepFunc(Module *parent, uint64_t &memSize);
@@ -295,17 +298,26 @@ class Partitioner {
295298
/// "Function Family", that is, without considerting the "dynamic stuff" (i.e.
296299
/// batch size, input/output shape of each op), all the functions are
297300
/// identical. The required memory and computation cost for each op can be
298-
/// found in Module. The \p devices provides the cost model related to
299-
/// devices.
301+
/// found in Module.
302+
/// The \p devices provides the cost model related to devices.
303+
/// Saturating the host will be enabled if \p saturateHost is true.
304+
/// \p optimized is false by default, which means the functions in this module
305+
/// are not optimized. \p partitionConfig contains the user defined partition
306+
/// info.
300307
Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
301-
bool saturateHost = false, bool optimized = false);
308+
bool saturateHost = false, bool optimized = false,
309+
PartitionConfig partitionConfig = PartitionConfig());
302310

303311
/// Users can create Mock Backends and pass their points to test Graph
304312
/// Partitioning without actually register them in GLOW.
305313
Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
306314
const std::vector<Backend *> &backends, bool saturateHost = false,
307315
bool optimized = false);
308316

317+
/// Based on partitionConfig_ passed into Partitioner, do the user-defined
318+
/// partition.
319+
llvm::Error PartitionFromConfig();
320+
309321
/// Decompose each function in a module. Now we support partitioning a module
310322
/// among different type of devices. \p cctx is used during optimization of
311323
/// the Function. \returns whether there was an error encountered.

include/glow/Runtime/RuntimeTypes.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,30 @@ struct HostConfig {
171171
size_t executorThreads{3};
172172
};
173173

174+
/// This is struct for user defined partition.
175+
struct PartitionConfig {
176+
/// The name of the function to be partitioned.
177+
std::string funcName;
178+
/// The number of user defined partitions.
179+
/// The partition ids are between 0 and numOfPartitions - 1, inclusive.
180+
size_t numOfPartitions;
181+
/// The backend for each partition. backendNames.size() == numOfPartitions.
182+
std::vector<std::string> backendNames;
183+
/// The name for each partition. partitionNames.size() == numOfPartitions.
184+
std::vector<std::string> partitionNames;
185+
/// The mapping between nodes' name to Partition ids. Assume there are n nodes
186+
/// and m partitions. We have 2 types of valid mapping: 1. all nodes are
187+
/// mapped to a partition. 2. For i-th (0 <= i < m) partition, the nodes
188+
/// mapped to this partition id are not in this map, and the nodes mapped to
189+
/// other partitions ids must be in this map. The node's name should be the
190+
/// name in Glow function and may be different from the original name from
191+
/// models. Since Glow will mangle names to make them unique.
192+
llvm::StringMap<size_t> nodeToPartition;
193+
194+
PartitionConfig() : numOfPartitions(0) {}
195+
bool enabled() { return numOfPartitions > 0; }
196+
};
197+
174198
} // namespace runtime
175199
} // namespace glow
176200
#endif // GLOW_RUNTIME_RUNTIMETYPES_H

lib/Partitioner/Partitioner.cpp

Lines changed: 113 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@
2222
#include "llvm/Support/Format.h"
2323
#include "llvm/Support/raw_ostream.h"
2424

25-
#include <fstream>
26-
2725
#include "llvm/Support/CommandLine.h"
2826
#include "llvm/Support/raw_ostream.h"
2927

28+
#include <fstream>
3029
namespace glow {
3130
bool GlowEnableLoadBalancedPartitioning = false;
3231
static llvm::cl::opt<bool, /* ExternalStorage */ true>
@@ -62,6 +61,20 @@ bool sortMinMemory(const std::pair<Function *, uint64_t> &a,
6261
return a.second < b.second;
6362
}
6463

64+
static void dumpPartitionInfo(const NodeToFunctionMap &partitions) {
65+
int i = 0;
66+
for (Function *subF : partitions.getPartitions()) {
67+
LOG(INFO) << "\t Partition " << i++ << ":\n"
68+
<< "\t\t Name :\t" << subF->getName().str() << "\n"
69+
<< "\t\t BackendKind :\t"
70+
<< partitions.getPartitionBackendName(subF) << "\n"
71+
<< "\t\t Memory :\t"
72+
<< partitions.getGraphMemInfo(subF).getTotalMemSize() << "\n"
73+
<< "\t\t LogicalDeviceIDs :\t"
74+
<< partitions.getLogicalDeviceIDList(subF)[0] << "\n";
75+
}
76+
}
77+
6578
void Partitioner::dumpDAG(llvm::StringRef dotFilename) const {
6679
if (partitions_.size() == 0)
6780
return;
@@ -168,9 +181,10 @@ Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
168181
}
169182

170183
Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
171-
bool saturateHost, bool optimized)
184+
bool saturateHost, bool optimized,
185+
PartitionConfig partitionConfig)
172186
: module_(parent), deviceInfo_(devices), saturateHost_(saturateHost),
173-
optimized_(optimized) {
187+
optimized_(optimized), partitionConfig_(partitionConfig) {
174188
memSize_ = module_->getConstantsSize();
175189
logicalDeviceID_ = 0;
176190
}
@@ -1211,7 +1225,6 @@ llvm::Error Partitioner::QuantizationProfilingPartition(
12111225
module_->eraseFunction(F_);
12121226
std::unique_ptr<Backend> backend(createBackend(profilingBackend));
12131227
for (Function *subF : module_->getFunctions()) {
1214-
(void)subF;
12151228
assert(subF->verify() && "Conversion led to invalid function");
12161229
if (!optimized_) {
12171230
RETURN_IF_ERR(::glow::optimizeFunction(subF, *backend, cctx));
@@ -1231,6 +1244,11 @@ llvm::Error Partitioner::Partition(CompilationContext &cctx) {
12311244
std::vector<std::unique_ptr<Backend>> backendHolder;
12321245
getBackendMap(backendMap_, backendHolder, backends);
12331246

1247+
if (partitionConfig_.enabled()) {
1248+
// Jump into user-defined partition, and skip the following auto partition.
1249+
return PartitionFromConfig();
1250+
}
1251+
12341252
// Step 0: Find the representative function for running partitioning
12351253
// algorithm.
12361254
F_ = selectRepFunc(module_, memSize_);
@@ -1348,27 +1366,104 @@ llvm::Error Partitioner::Partition(CompilationContext &cctx) {
13481366
dumpDAG("DAG.dot");
13491367
}
13501368

1351-
int i = 0;
13521369
for (Function *subF : funcList) {
1353-
(void)subF;
1354-
if (logPartition) {
1355-
LOG(INFO) << "\t Partition " << i << ":\n"
1356-
<< "\t\t Name :\t" << subF->getName().str() << "\n"
1357-
<< "\t\t BackendKind :\t"
1358-
<< mapping.getPartitionBackendName(subF) << "\n"
1359-
<< "\t\t Memory :\t"
1360-
<< mapping.getGraphMemInfo(subF).getTotalMemSize() << "\n"
1361-
<< "\t\t LogicalDeviceIDs :\t"
1362-
<< mapping.getLogicalDeviceIDList(subF)[0] << "\n";
1363-
}
13641370
if (dumpPartition) {
13651371
subF->dumpDAG("partitionLogicalID" +
13661372
std::to_string(mapping.getLogicalDeviceIDList(subF)[0]) +
13671373
"__" + subF->getFilename() + "__" +
13681374
mapping.getPartitionBackendName(subF) + ".dot");
13691375
}
1370-
i++;
13711376
assert(subF->verify() && "Conversion led to invalid function");
13721377
}
1378+
if (logPartition) {
1379+
dumpPartitionInfo(mapping);
1380+
}
1381+
return llvm::Error::success();
1382+
}
1383+
1384+
llvm::Error Partitioner::PartitionFromConfig() {
1385+
Function *F = module_->getFunction(partitionConfig_.funcName);
1386+
RETURN_ERR_IF_NOT(F, strFormat("Can't find function %s in current module.",
1387+
F->getName().str().data()));
1388+
1389+
DCHECK(partitionConfig_.numOfPartitions ==
1390+
partitionConfig_.backendNames.size() &&
1391+
partitionConfig_.numOfPartitions ==
1392+
partitionConfig_.partitionNames.size())
1393+
<< "Invalid user-defined partition config.";
1394+
1395+
NodeToFunctionMap partitionMap;
1396+
std::vector<Function *> funcList;
1397+
std::unordered_set<size_t> unused;
1398+
std::vector<NodesSet> nodesSets(partitionConfig_.numOfPartitions);
1399+
// Create partitions based on the given number and names.
1400+
for (size_t i = 0; i < partitionConfig_.numOfPartitions; i++) {
1401+
Function *newF =
1402+
module_->createFunction(partitionConfig_.partitionNames[i]);
1403+
funcList.push_back(newF);
1404+
partitionMap.createPartition(newF, partitionConfig_.backendNames[i]);
1405+
unused.insert(i);
1406+
}
1407+
1408+
// Map the nodes the the partitions.
1409+
std::vector<Node *> unMapped;
1410+
for (auto &node : F->getNodes()) {
1411+
auto iter = partitionConfig_.nodeToPartition.find(node.getName());
1412+
if (iter == partitionConfig_.nodeToPartition.end()) {
1413+
// If a node in F is not in the node to partition mapping, put it into
1414+
// unMaped list.
1415+
unMapped.push_back(&node);
1416+
} else {
1417+
size_t partitionID = iter->second;
1418+
DCHECK(partitionID < partitionConfig_.numOfPartitions)
1419+
<< "Invalid partition id :" << partitionID;
1420+
partitionMap.add(&node, funcList[partitionID]);
1421+
unused.erase(partitionID);
1422+
nodesSets[partitionID].insert(&node);
1423+
}
1424+
}
1425+
1426+
// If there is unused partition and unmapped nodes, map those nodes to the
1427+
// unused partition.
1428+
if (unMapped.size()) {
1429+
DCHECK(unused.size() == 1) << "There must be exactly 1 unused partition.";
1430+
auto partitionID = *(unused.begin());
1431+
for (auto &node : unMapped) {
1432+
partitionMap.add(node, funcList[partitionID]);
1433+
nodesSets[partitionID].insert(node);
1434+
}
1435+
}
1436+
1437+
// Validate memory usage.
1438+
for (size_t i = 0; i < partitionConfig_.numOfPartitions; i++) {
1439+
GraphMemInfo cost = getGraphMemInfo(nodesSets[i]);
1440+
partitionMap.setGraphMemInfo(funcList[i], cost);
1441+
}
1442+
RETURN_IF_ERR(memoryUsageValidation(partitionMap));
1443+
1444+
// Logical device ID validation.
1445+
logicalDeviceID_ = assignLogicalDeviceID(partitionMap);
1446+
RETURN_IF_ERR(logicalDevicesValidation(partitionMap));
1447+
1448+
// TODO : loop-free validation.
1449+
1450+
// Do partition.
1451+
doPartitioning(F->getName(), {F}, partitionMap, true);
1452+
module_->eraseFunction(F);
1453+
1454+
// Do optimization based on backendName.
1455+
for (size_t i = 0; i < partitionConfig_.numOfPartitions; i++) {
1456+
auto func = funcList[i];
1457+
assert(func->verify() && "Conversion led to invalid function");
1458+
std::unique_ptr<Backend> backend(
1459+
createBackend(partitionConfig_.backendNames[i]));
1460+
if (!optimized_) {
1461+
CompilationContext cctx;
1462+
RETURN_IF_ERR(::glow::optimizeFunction(func, *backend, cctx));
1463+
}
1464+
}
1465+
if (logPartition) {
1466+
dumpPartitionInfo(partitionMap);
1467+
}
13731468
return llvm::Error::success();
13741469
}

0 commit comments

Comments
 (0)