diff --git a/lib/Partitioner/Partitioner.cpp b/lib/Partitioner/Partitioner.cpp index 9877ff3688..1e32e5c9cb 100644 --- a/lib/Partitioner/Partitioner.cpp +++ b/lib/Partitioner/Partitioner.cpp @@ -145,8 +145,8 @@ NodeToFunctionMap Partitioner::selectPartitions(Function *F, NodeToFunctionMap mapping; BFSLevel bfs = getBFSLevel(F); unsigned level = bfs.levels.size(); - // A list of cut. The graph can be partitioned by levels [level - 1, - // cut[0]), [cut[0] - 1, cut[1]), ..., [cut[n], -1). + // A list of cuts. The graph can be partitioned by levels (cut[0], level - 1], + // (cut[1], cut[0]], ..., (-1, cut[n - 1]]. std::vector cut; // Step 1 : get the initial cut based on BFS levels and avaiableMemory. @@ -159,10 +159,11 @@ NodeToFunctionMap Partitioner::selectPartitions(Function *F, tmp += memUsage_[N]; } if (mem + tmp > availableMemory) { + // mem == 0 means the mem usage for one level exceeds the availableMem, + // accept it now and will do adjustment later. Otherwise, leave tmp to + // next stage by assigning it to mem. if (mem == 0) { - // This means the mem usage for one level exceeds the availableMem, - // accept it now and will do adjustment later. - cut.push_back(i + 1); + cut.push_back(i - 1); } else { cut.push_back(i); mem = tmp; @@ -176,13 +177,24 @@ NodeToFunctionMap Partitioner::selectPartitions(Function *F, cut.push_back(-1); // Step 2 : Create the initial mapping between node and functions. + int color = 0; + Function *newF; for (int k = 0, e = cut.size(); k < e; k++) { - auto *newF = F->getParent()->createFunction(std::string(F->getName()) + - "_part" + std::to_string(k)); + newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" + + std::to_string(++color)); mapping.createPartition(newF); + unsigned mem = 0; for (int i = k > 0 ? 
cut[k - 1] : level - 1; i > cut[k]; i--) { for (int j = 0, e1 = bfs.levels[i].second.size(); j < e1; j++) { Node *N = bfs.levels[i].second[j]; + if (mem + memUsage_[N] > availableMemory) { + newF = F->getParent()->createFunction( + std::string(F->getName()) + "_part" + std::to_string(++color)); + mapping.createPartition(newF); + mem = memUsage_[N]; + } else { + mem += memUsage_[N]; + } mapping.add(N, newF); } } @@ -308,11 +320,9 @@ DAGNodeList &Partitioner::Partition() { // Find the representive function for running partitioning algrithm. F_ = selectRepFunc(module_, memSize_); - // Possible minimal k devices for a successful partitioning - // Note: here 2 is for testing; - unsigned k = 2; //(memSize_ + MARGIN) / devices[0].availableMemory; + unsigned availMem = deviceInfo_[0].availableMemory; - if (k == 1) { + if (memSize_ < availMem) { // No partition is needed. Create DAGNode and return. This root is alway a // dummy function. for (auto F : module_->getFunctions()) { @@ -340,9 +350,7 @@ DAGNodeList &Partitioner::Partition() { // Partition // Use BFS to do the initial partitioning. Starting from the final node, BFS // until the memory limitation reached one by one. - unsigned unitMem = memSize_ / k; // used for testing - - NodeToFunctionMap partitionMap = selectPartitions(F_, unitMem); + NodeToFunctionMap partitionMap = selectPartitions(F_, availMem); doPartitioning(F_, partitionMap); diff --git a/tests/unittests/PartitionerTest.cpp b/tests/unittests/PartitionerTest.cpp index e8d57d9102..0ab51752c1 100644 --- a/tests/unittests/PartitionerTest.cpp +++ b/tests/unittests/PartitionerTest.cpp @@ -65,25 +65,48 @@ static void executeDAG(DAGNode *G, Module &mod, Context &ctx, } } -TEST_F(PartitionerTest, test1) { +/// This one tests the model with this feature: after BFS, the memory +/// consumption of all the nodes in each level won't exceed the device memory +/// constraints. 
+TEST_F(PartitionerTest, Basic1) { auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false); + auto *w1 = mod_.createConstant(ElemKind::FloatTy, {32, 16}, "w1"); + auto *b1 = mod_.createConstant(ElemKind::FloatTy, {16}, "b1"); ctx_.allocate(input); + w1->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b1->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); // Initial FC. - Node *I = F_->createFullyConnected(ctx_, "initial_fc", input, 16); + Node *I = F_->createFullyConnected("initial_fc", input, w1, b1); I = F_->createSigmoid("initial_sigmoid", I); // Left branch. - Node *L = F_->createFullyConnected(ctx_, "left_fc1", I, 16); + auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2"); + auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2"); + w2->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b2->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + Node *L = F_->createFullyConnected("left_fc1", I, w2, b2); L = F_->createSigmoid("left_sigmoid1", L); - L = F_->createFullyConnected(ctx_, "left_fc2", L, 8); + auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3"); + auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3"); + w3->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b3->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + L = F_->createFullyConnected("left_fc2", L, w3, b3); L = F_->createSigmoid("left_sigmoid2", L); // Right branch. 
- Node *R = F_->createFullyConnected(ctx_, "right_fc1", I, 16); + auto *w4 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w4"); + auto *b4 = mod_.createConstant(ElemKind::FloatTy, {16}, "b4"); + w4->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b4->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + Node *R = F_->createFullyConnected("right_fc1", I, w4, b4); R = F_->createSigmoid("right_sigmoid1", R); - R = F_->createFullyConnected(ctx_, "right_fc2", R, 8); + auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5"); + auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5"); + w5->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b5->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + R = F_->createFullyConnected("right_fc2", R, w5, b5); R = F_->createSigmoid("right_sigmoid2", R); // Join branches. @@ -100,7 +123,76 @@ TEST_F(PartitionerTest, test1) { EE.run(ctx_); Tensor ref = res.clone(); - std::vector devices; + std::vector devices = {{3072}, {3072}, {3072}}; + Partitioner myPartitioner(&mod_, devices); + + DAGNodeList myList = std::move(myPartitioner.Partition()); + ASSERT_EQ(mod_.getFunctions().size(), 3); + ASSERT_EQ(myList.roots.size(), 1); + + // Run the partitioned graph and compare the results. + ctx_.allocate(mod_.getPlaceholders()); + for (auto it = myList.roots.begin(); it != myList.roots.end(); ++it) { + ctx_.allocate(mod_.getPlaceholders()); + executeDAG((*it).get(), mod_, ctx_, {input}, {&in}); + Tensor test = res.clone(); + EXPECT_TRUE(ref.isEqual(test)); + } +} + +/// This one tests the model with this feature: after BFS, there is one level, +/// the memory consumption of all the nodes in which exceeds the device memory +/// constraints. +TEST_F(PartitionerTest, Basic2) { + auto *input = + mod_.createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false); + auto *input1 = + mod_.createPlaceholder(ElemKind::FloatTy, {1, 16}, "input1", false); + ctx_.allocate(input); + ctx_.allocate(input1); + // Left branch. 
+ auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2"); + auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2"); + w2->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b2->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + Node *L = F_->createFullyConnected("left_fc1", input, w2, b2); + L = F_->createSigmoid("left_sigmoid1", L); + auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3"); + auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3"); + w3->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b3->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + L = F_->createFullyConnected("left_fc2", L, w3, b3); + L = F_->createSigmoid("left_sigmoid2", L); + + // Right branch. + auto *w4 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w4"); + auto *b4 = mod_.createConstant(ElemKind::FloatTy, {16}, "b4"); + w4->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b4->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + Node *R = F_->createFullyConnected("right_fc1", input1, w4, b4); + R = F_->createSigmoid("right_sigmoid1", R); + auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5"); + auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5"); + w5->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + b5->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); + R = F_->createFullyConnected("right_fc2", R, w5, b5); + R = F_->createSigmoid("right_sigmoid2", R); + + // Join branches. + auto *mul = F_->createMul("mul", L, R); + auto *save = F_->createSave("ret", mul); + auto &res = *ctx_.allocate(save->getPlaceholder()); + + // Infer using the un-partitioned graph. 
+ Tensor in(ElemKind::FloatTy, {1, 16}); + ExecutionEngine EE; + + EE.compile(CompilationMode::Infer, F_); + updateInputPlaceholders(ctx_, {input, input1}, {&in, &in}); + EE.run(ctx_); + Tensor ref = res.clone(); + + std::vector devices = {{2048}, {2048}, {2048}}; Partitioner myPartitioner(&mod_, devices); DAGNodeList myList = std::move(myPartitioner.Partition());