diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index 8fcd29f7a8..70175f0131 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -181,34 +181,15 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo seg_block.registerOutput(mini_graph_input); } } - // if no output, then register the last node's output as current graph's output + // if no output, then register this graph's input as its output + // We can ensure that TRT segmented block has Tensor inputs now if (seg_block.raw_outputs().empty()) { - // for Torch segments, register input as output - if (seg_block.target() == SegmentedBlock::kTorch) { - seg_block.registerOutput(seg_block.raw_inputs()[0]); - } else { - // for TensorRT segments, register last nonInput Tensor outputs - for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) { - for (auto node_output : seg_block.raw_nodes()[i]->outputs()) { - if (isTensor(node_output)) - seg_block.registerOutput(node_output); - } - if (!seg_block.raw_outputs().empty()) - break; - } - } + seg_block.registerOutput(seg_block.raw_inputs()[0]); } } std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) { torch::jit::EliminateDeadCode(seg_block.g()); }); - // erase segments which still have no output - segmented_blocks.erase( - std::remove_if( - segmented_blocks.begin(), - segmented_blocks.end(), - [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }), - segmented_blocks.end()); return; }