diff --git a/core/compiler.cpp b/core/compiler.cpp
index b684b808f5..4a4389bea3 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -198,7 +198,8 @@ void AddIfBlockToGraph(
     auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
     new_if_block->cloneFrom(cur_block_graph->block(), env);
-    if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
+    if (cur_block_graph->inputs().size() &&
+        cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
       if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
         auto self = new_g->insertInput(0, "self_1");
         self->setType(cur_block_graph->inputs()[0]->type());
@@ -223,13 +224,14 @@ GraphAndMapping ConstructFallbackGraph(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
     CompileSpec cfg,
-    ir::StaticParams static_params) {
+    ir::StaticParams static_params,
+    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   auto convert_cfg = cfg.convert_info;
   auto partition_info = cfg.partition_info;

   auto new_g = std::make_shared<torch::jit::Graph>();

-  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
+  auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);

   // the mapping from lowering graph => fallback global graph
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -270,7 +272,7 @@ GraphAndMapping ConstructFallbackGraph(
       std::vector<GraphAndMapping> graph_and_mappings;
       for (auto cur_block : if_node->blocks()) {
         graph_and_mappings.push_back(
-            ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
+            ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
       }
       AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

@@ -293,7 +295,7 @@ GraphAndMapping ConstructFallbackGraph(
       // Set the output as the produced tuple
       new_g->registerOutput(return_tuple_node->outputs()[0]);
     } else {
-      if (old_to_new_g.count(block->outputs()[0])) {
+      if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
         new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
       }
     }
@@ -430,7 +432,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
     auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
-    auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
+    std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+    auto graph_and_mapping =
+        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
     new_g = graph_and_mapping.first;
     LOG_INFO("Segmented Graph: " << *new_g);

diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 63161217e4..476d6fcfba 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -30,34 +30,49 @@ inline bool isTensor(torch::jit::Value* val) {
   return val->type()->isSubtypeOf(torch::jit::TensorType::get());
 }

-bool isAllNodesSupported(const std::vector<torch::jit::Node*>& nodes) {
-  for (auto node : nodes) {
-    if (!conversion::OpSupported(node)) {
-      return false;
+bool containNonTensorOutputs(torch::jit::Node* n) {
+  for (auto output : n->outputs()) {
+    if (!isTensorOrTensorList(output)) {
+      return true;
     }
   }
-  return true;
+  return false;
 }

-bool containTargetInputs(torch::jit::Node* n, const std::unordered_set<torch::jit::Value*>& target_inputs) {
-  for (auto input : n->inputs()) {
-    if (!isTensorOrTensorList(input) && target_inputs.count(input)) {
-      return true;
+bool isModifyingNodes(torch::jit::Node* node, torch::jit::Value* val) {
+  const auto& schema = node->schema();
+  for (size_t i = 0; i < node->inputs().size(); ++i) {
+    if (node->inputs()[i] == val) {
+      const at::AliasInfo* formal = schema.arguments()[i].alias_info();
+      if (formal && formal->isWrite()) {
+        return true;
+      }
     }
   }
   return false;
 }

-bool containNonTensorOutputs(torch::jit::Node* n) {
-  for (auto output : n->outputs()) {
-    if (!isTensorOrTensorList(output)) {
-      return true;
+std::vector<torch::jit::Node*> findModifyingNodes(
+    torch::jit::Value* val,
+    const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
+  std::vector<torch::jit::Node*> modifying_nodes;
+  for (auto use : val->uses()) {
+    torch::jit::Node* node = use.user;
+    if (seg_block_nodes.find(node) != seg_block_nodes.end()) {
+      break;
+    }
+    if (isModifyingNodes(node, val)) {
+      modifying_nodes.push_back(node);
     }
   }
-  return false;
+  return modifying_nodes;
 }

-std::vector<torch::jit::Node*> getDependencyNodes(const std::vector<torch::jit::Value*>& vals) {
+std::vector<torch::jit::Node*> getDependencyNodes(
+    const std::vector<torch::jit::Value*>& vals,
+    const SegmentedBlock& seg_block) {
+  // get all nodes in the segmented block
+  std::unordered_set<torch::jit::Node*> seg_block_nodes(seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
   // use bfs to get the DAG dependency nodes for input value
   std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
       std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
@@ -69,6 +84,8 @@ std::vector<torch::jit::Node*> getDependencyNodes(const std::vector<torch::jit::
     auto node = cur_val->node();
     if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
       visited.insert(node);
+      auto modifying_nodes = findModifyingNodes(cur_val, seg_block_nodes);
+      stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
       stk.push_back(node);
       for (auto input : node->inputs()) {
         if (!isTensorOrTensorList(input)) {
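Note: isModifyingNodes() above keys entirely on schema alias annotations — a formal argument carrying the "(a!)" write marker is mutated in place. A standalone sketch of that mechanism (illustrative program, not part of this patch; it uses only the parseSchema/alias_info/isWrite calls the patch itself relies on):

#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <iostream>

int main() {
  // aten::append mutates its first argument in place; the (a!) annotation
  // below is exactly what isModifyingNodes() reads via alias_info().
  auto schema = torch::jit::parseSchema("aten::append.t(t[](a!) self, t el) -> t[](a!)");
  const at::AliasInfo* formal = schema.arguments()[0].alias_info();
  std::cout << "writes arg 0: " << (formal && formal->isWrite()) << std::endl; // prints 1
  return 0;
}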
@@ -81,62 +98,29 @@ std::vector<torch::jit::Node*> getDependencyNodes(const std::vector<torch::jit::
   }
   return stk;
 }

-std::vector<torch::jit::Node*> getOutputNodes(
-    torch::jit::Value* value,
-    const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
-  // use bfs to get the DAG outputs nodes for input value
-  std::queue<torch::jit::Value*> q;
-  std::vector<torch::jit::Node*> stk;
-  std::unordered_set<torch::jit::Node*> visited;
-  q.push(value);
-
-  // top-down order traversing
-  while (!q.empty()) {
-    auto cur_val = q.front();
-    q.pop();
-    for (auto use : cur_val->uses()) {
-      auto node = use.user;
-      // use node must be in seg_block_nodes
-      if (seg_block_nodes.count(node) && !visited.count(node)) {
-        stk.push_back(node);
-        visited.insert(node);
-        // travel its' all outputs
-        for (auto output : node->outputs()) {
-          if (!isTensor(output)) {
-            q.push(output);
-          }
-        }
-      }
-    }
-  }
-
-  // top-down order and we don't need to reverse it
-  return stk;
-}
-
-void getDirtyNodes(
-    std::unordered_set<torch::jit::Node*>& dirty_nodes,
-    const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
+void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   std::queue<torch::jit::Node*> q;
-  for (auto& node : dirty_nodes) {
-    q.push(node);
+  for (auto& node : fallback_nodes) {
+    q.push(node.first);
   }
-  dirty_nodes.clear();
+
+  std::unordered_set<torch::jit::Node*> visited_nodes;
   while (!q.empty()) {
     auto cur_node = q.front();
     q.pop();
-    if (!dirty_nodes.count(cur_node) && seg_block_nodes.count(cur_node)) {
-      dirty_nodes.insert(cur_node);
-      for (auto input : cur_node->inputs()) {
-        if (!isTensorOrTensorList(input)) {
-          q.push(input->node());
-        }
+    // every node that produces one of this fallback node's NonTensor inputs should fall back too
+    for (auto input : cur_node->inputs()) {
+      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
+          fallback_nodes.insert({input->node(), 4}).second) {
+        q.push(input->node());
       }
-      for (auto output : cur_node->outputs()) {
-        if (!isTensorOrTensorList(output)) {
-          for (auto use : output->uses()) {
-            auto node = use.user;
+    }
+    // every node that consumes one of this fallback node's NonTensor outputs should fall back too
+    for (auto output : cur_node->outputs()) {
+      if (!isTensor(output)) {
+        for (auto use : output->uses()) {
+          auto node = use.user;
+          if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
            q.push(node);
          }
        }
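The propagation in find_all_fallback_nodes() is a plain BFS worklist over non-tensor def/use edges; fallback_nodes.insert({n, 4}).second doubles as the visited check, so each node is enqueued at most once. A minimal toy version of the same pattern (hypothetical integer node ids, not the Torch-TensorRT API):

#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

int main() {
  // Toy graph: node id -> ids reachable through non-tensor values.
  std::unordered_map<int, std::vector<int>> non_tensor_edges = {{0, {1}}, {1, {2}}, {2, {}}, {3, {}}};
  std::unordered_map<int, int> fallback_nodes = {{0, 0}}; // seed: node 0 unsupported (reason 0)
  std::queue<int> q;
  q.push(0);
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    for (int next : non_tensor_edges[cur]) {
      // The first insertion returns .second == true, so each node is visited
      // once; reason code 4 = "dragged in by a non-tensor dependency".
      if (fallback_nodes.insert({next, 4}).second) {
        q.push(next);
      }
    }
  }
  std::cout << fallback_nodes.size() << std::endl; // 3: nodes 0, 1, 2 fall back; node 3 stays
  return 0;
}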
@@ -145,254 +129,26 @@ void getDirtyNodes(
   }
 }

-std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock> segmentBlocksWithTensorListInputs(
-    SegmentedBlock& seg_block,
-    const std::unordered_map<torch::jit::Value*, SegmentedBlock>& tensorlist_inputs) {
-  std::unordered_set<torch::jit::Node*> all_append_nodes;
-  std::unordered_map<torch::jit::Value*, SegmentedBlock> append_blocks;
-  const std::unordered_set<torch::jit::Node*> seg_block_nodes(
-      seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-  for (auto input_pair : tensorlist_inputs) {
-    auto append_nodes = getOutputNodes(input_pair.first, seg_block_nodes);
-    append_blocks[input_pair.first] = SegmentedBlock(input_pair.second.target(), append_nodes);
-    all_append_nodes.insert(append_nodes.begin(), append_nodes.end());
-  }
-
-  std::vector<torch::jit::Node*> trt_nodes;
-  for (auto node : seg_block.raw_nodes()) {
-    if (all_append_nodes.count(node) == 0) {
-      trt_nodes.emplace_back(node);
-    }
-  }
-  SegmentedBlock trt_block(SegmentedBlock::kTensorRT, trt_nodes);
-
-  return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
-}
-
-PartitionedGraph segmentBlocksWithSpecifiedInputs(
-    SegmentedBlock& seg_block,
-    std::vector<torch::jit::Value*>& inputs_to_resolve) {
-  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
-  PartitionedGraph new_seg_blocks;
-  // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
-  // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
-  if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
-    // if current node is prim::If, just ensure that we have all required input in kTorch
-    if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-      new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
-      new_seg_blocks.push_back(seg_block);
-    } else {
-      dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-      new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
-    }
-  } else {
-    // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
-    std::unordered_set<torch::jit::Value*> inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end());
-    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
-
-    // take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use
-    // dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produces non_tensor
-    // values consumed by dirty nodes
-    std::unordered_set<torch::jit::Node*> dirty_nodes;
-    const std::unordered_set<torch::jit::Node*> seg_block_nodes(
-        seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
-
-    for (auto n : seg_block.raw_nodes()) {
-      if (containTargetInputs(n, inputs_to_resolve_set)) {
-        dirty_nodes.insert(n);
-      }
-    }
-    getDirtyNodes(dirty_nodes, seg_block_nodes);
-    for (auto n : seg_block.raw_nodes()) {
-      if (dirty_nodes.count(n)) {
-        if (!tensorrt_nodes.empty()) {
-          new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
-          tensorrt_nodes.clear();
+void resolveTRTNonTensorInputs(PartitionedGraph& segmented_blocks) {
+  // if a TRT segment has nonTensor inputs, the nodes that produce these nonTensor inputs must be in another TensorRT
+  // engine, because we have already found the interface between Torch and TRT in the segmentation phase;
+  // what we do here is just find the dependency nodes of the TRT segments that have nonTensor inputs
+  for (size_t i = 0; i < segmented_blocks.size(); ++i) {
+    if (segmented_blocks[i].target() == SegmentedBlock::kTensorRT) {
+      std::vector<torch::jit::Value*> inputs_to_resolve;
+      for (auto input : segmented_blocks[i].raw_inputs()) {
+        if (!isTensor(input)) {
+          inputs_to_resolve.push_back(input);
         }
-        pytorch_nodes.push_back(n);
-      } else {
-        if (!pytorch_nodes.empty()) {
-          new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes);
-          pytorch_nodes.clear();
-        }
-        tensorrt_nodes.push_back(n);
       }
-    }
-
-    // Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly.
-    if (!tensorrt_nodes.empty()) {
-      new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
-    } else {
-      new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes);
-    }
-  }
-
-  return new_seg_blocks;
-}
-
-PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
-  // reconstruct segmented_block if this block requires nonTensor input
-  std::vector<torch::jit::Value*> inputs_to_resolve;
-  // Gather all non-tensor inputs for this block
-  for (auto input : seg_block.raw_inputs()) {
-    if (!isTensorOrTensorList(input)) {
-      inputs_to_resolve.push_back(input);
-    }
-  }
-  return segmentBlocksWithSpecifiedInputs(seg_block, inputs_to_resolve);
-}
-
-std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
-    const PartitionedGraph& segmented_blocks,
-    const std::function<bool(torch::jit::Value*)>& condition) {
-  // usage_counts is a map which stores non-tensor inputs as keys and the values are indices of segmented blocks which
-  // have these non-tensor inputs. Iterate through the graph (segmented blocks) from bottom to top. When we find a
-  // non-tensor input in a segmented block of index "i", store it in the usage_counts map. Now for each non-tensor
-  // inputs recorded in the usage_counts map, we check if any previous segmented block (segmented block index i goes
-  // from n-1 to 0) generated/contains this non-tensor input. If so, we set this idx as the produce_id as it produces
-  // the non-tensor input.
-  std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
-  for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
-    for (auto input : segmented_blocks[i].raw_inputs()) {
-      if (condition(input)) {
-        segmented_blocks[i].target() == SegmentedBlock::kTorch ? usage_counts[input].torch_use_id.push_back(i)
-                                                               : usage_counts[input].tensorrt_use_id.push_back(i);
-      }
-    }
-
-    // For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block
-    // that has/produces it.
-    for (auto& use : usage_counts) {
-      // Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
-      if (segmented_blocks[i].contain_raw_value(use.first)) {
-        use.second.produce_id = i;
-      }
-    }
-  }
-  return usage_counts;
-}
-
-std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> getIdxtoIterMap(
-    std::list<SegmentedBlock>& segmented_blocks_list) {
-  std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
-  auto iter = segmented_blocks_list.begin();
-  for (uint64_t i = 0; i < segmented_blocks_list.size(); ++i, ++iter) {
-    idx_to_iter[i] = iter;
-  }
-  return idx_to_iter;
-}
-
-void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
-  // get input usage counts and blocks_list
-  std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.cbegin(), segmented_blocks.cend());
-  auto usage_counts = getInputUsageCounts(
-      segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
-  auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);
-
-  std::map<size_t, std::vector<torch::jit::Value*>>
-      torch_values_to_fix; // Only need to resolve values generated by tensorrt
-  std::set<size_t> tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs
-
-  // update blocks_list
-  std::unordered_set<size_t> updated_segments;
-  for (auto& use : usage_counts) {
-    auto use_info = use.second;
-    // if the segment that produce this nonTensor value is kTensorRT but consumed in kTorch, inject nodes in the first
-    // kTorch segment.
-    if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
-      auto first_torch_id = use_info.torch_use_id.back();
-      torch_values_to_fix[first_torch_id].push_back(use.first);
-    }
-    // kTensorRT segments always need to inject nodes for the nonTensor inputs
-    for (auto i : use_info.tensorrt_use_id) {
-      tensorrt_blocks_to_fix.insert(i);
-    }
-  }
-  for (auto torch_block_pair : torch_values_to_fix) {
-    auto to_inject_blocks =
-        segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
-    auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
-    segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
-  }
-
-  for (auto i : tensorrt_blocks_to_fix) {
-    auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
-    auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
-    segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
-  }
-
-  segmented_blocks.clear();
-  segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
-  return;
-}
-
-void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) {
-  // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which
-  // produces/contains it.
-  auto usage_counts =
-      getInputUsageCounts(segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList(input); });
-
-  // Get idx of the segblock to its iterator mapping
-  std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.cbegin(), segmented_blocks.cend());
-  auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);
-
-  std::unordered_set<size_t> updated_segments;
-  // we need to re-segment TensorRT segments whose inputs are TensorLists
-  for (auto& use : usage_counts) {
-    auto use_info = use.second;
-    // For a particular tensorlist input, traverse through all ids of segmented blocks whose target is TensorRT
-    for (auto i : use_info.tensorrt_use_id) {
-      if (!updated_segments.count(i)) {
-        // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this
-        // tensorlist input}
-        std::unordered_map<torch::jit::Value*, SegmentedBlock> tensorlistinput_to_segblock;
-        for (auto input : segmented_blocks[i].raw_inputs()) {
-          if (isTensorList(input)) {
-            tensorlistinput_to_segblock[input] = segmented_blocks[usage_counts[input].produce_id];
-          }
-        }
-
-        // For each tensorlist input in tensorlistinput_to_segblock, get the node which actually uses this input.
-        // Once we retrieve the node, we remove it from the current TensorRT segmented_blocks[i]. This node should be
-        // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first
-        // place.
-        auto seg_blocks = segmentBlocksWithTensorListInputs(segmented_blocks[i], tensorlistinput_to_segblock);
-        auto append_blocks = seg_blocks.first;
-        auto trt_block = seg_blocks.second;
-        // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that
-        // uses tensorlist input removed.
-        auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
-        if (trt_block.raw_nodes().size() > 0) {
-          segmented_blocks_list.insert(next_iter, trt_block);
-        }
-
-        // append blocks' nodes to the producer seg_block
-        for (auto append_block : append_blocks) {
-          auto input = append_block.first; // corresponds to the tensorlist input
-          auto block = append_block.second;
-          // append nodes to segmented_blocks_list
-          auto producer = idx_to_iter[usage_counts[input].produce_id];
-          for (auto n : block.raw_nodes()) {
-            producer->cloneNode(n);
-          }
-        }
-        updated_segments.insert(i);
+      if (!inputs_to_resolve.empty()) {
+        std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve, segmented_blocks[i]);
+        dependency_nodes.insert(
+            dependency_nodes.end(), segmented_blocks[i].raw_nodes().begin(), segmented_blocks[i].raw_nodes().end());
+        segmented_blocks[i] = SegmentedBlock(SegmentedBlock::kTensorRT, dependency_nodes);
       }
     }
   }
-  segmented_blocks.clear();
-  segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
-  return;
-}
-
-void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shared_ptr<torch::jit::Graph> g
-  // make sure that all inputs should be tensor
-  LOG_DEBUG("Resolving nonTensor inputs/outputs of segmented_blocks");
-  resolveNonTensorInputBlocks(segmented_blocks);
-
-  // we need to re-segment tensorrt blocks whose inputs are tensorlists (eg: Tensor [] instead of Tensor).
-  LOG_DEBUG("Resolving inputs of type TensorList in segmented_blocks");
-  resolveTensorListInputBlocks(segmented_blocks);
 }

 void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
@@ -430,7 +186,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
       // for TensorRT segments, register last nonInput Tensor outputs
       for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
         for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
-          if (isTensorOrTensorList(node_output))
+          if (isTensor(node_output))
             seg_block.registerOutput(node_output);
         }
         if (!seg_block.raw_outputs().empty())
@@ -467,23 +223,19 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
   return compile_to_trt;
 }

-bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
-  // If the op is not supported by the conversion phase it should run in PyTorch
-  if (!conversion::OpSupported(n)) {
-    LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    return false;
-  }
-
-  // If the user specifies the op to run in Torch it should run in PyTorch
-  if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
-    LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    return false;
-  }
-
-  // If the user specifies the module containing this op to run in torch it should run in PyTorch
-  const auto to_compile_sym = c10::Symbol::attr("to_compile");
-  if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-    LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+  if (fallback_nodes.count(n)) {
+    if (fallback_nodes.at(n) == 0) {
+      LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
+    } else if (fallback_nodes.at(n) == 1) {
+      LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
+    } else if (fallback_nodes.at(n) == 2) {
+      LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    } else {
+      LOG_GRAPH(
+          "Node falls back to Torch because of NonTensor dependencies with other fallback nodes: "
+          << util::node_info(n));
+    }
     return false;
   }

@@ -501,12 +253,56 @@ void finalize_block(
   LOG_DEBUG(g.back());
 }

-PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
+// use this function to get all initial fallback nodes (nodes that are unsupported or forced fallback)
+// we use a map to indicate the reason why each node falls back to torch
+void get_fallback_nodes(
+    torch::jit::Block* block,
+    const std::unordered_set<std::string>& forced_fallback_ops,
+    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+  auto nodes = block->nodes();
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant) {
+      continue;
+    }
+
+    // If the op is not supported by the conversion phase it should run in PyTorch
+    if (!conversion::OpSupported(n)) {
+      fallback_nodes.insert({n, 0});
+    }
+
+    // If the user specifies the op to run in Torch it should run in PyTorch
+    if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
+      fallback_nodes.insert({n, 1});
+    }
+
+    // If the user specifies the module containing this op to run in torch it should run in PyTorch
+    const auto to_compile_sym = c10::Symbol::attr("to_compile");
+    if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
+      fallback_nodes.insert({n, 2});
+    }
+  }
+  return;
+}
+
+PartitionedGraph segment_graph(
+    torch::jit::Block* block,
+    const PartitionInfo& partition_info,
+    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());

+  // get the initial fallback nodes (nodes that are unsupported or forced fallback)
+  get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+
+  // For each fallback node, if it consumes any NonTensor or TensorList inputs, the node that produces that
+  // input should also fall back. Similarly, if it produces any NonTensor or TensorList outputs, the nodes
+  // that consume those outputs should also fall back.
+  // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
+  find_all_fallback_nodes(fallback_nodes);
+
   auto nodes = block->nodes();
+
   PartitionedGraph segmented_blocks;

   // segment the nodes
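The reason codes written by get_fallback_nodes() and the propagation pass are bare ints (0, 1, 2, 4). Purely as an illustration of what each value means — a named-constant sketch, not something this patch introduces:

// Hypothetical names for the int codes stored in fallback_nodes:
enum FallbackReason : int {
  kUNSUPPORTED = 0,            // conversion::OpSupported(n) returned false
  kFORCED_FALLBACK_OP = 1,     // listed in partition_info.forced_fallback_operators
  kFORCED_FALLBACK_MODULE = 2, // enclosing module marked to_compile = false
  kNON_TENSOR_DEPENDENCY = 4   // added by the find_all_fallback_nodes() propagation
};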
@@ -517,7 +313,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
       continue;
     }

-    if (should_run_in_trt(n, forced_fallback_ops)) {
+    if (check_node_fallback(n, fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);

       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -570,7 +366,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
     finalize_block(segmented_blocks, SegmentedBlock::kTensorRT, in_prog_trt_blk_nodes);
   }

-  if (!in_prog_pyt_blk_nodes.empty()) {
+  if (!in_prog_pyt_blk_nodes.empty() || !in_prog_trt_blk_nodes.empty()) {
     in_prog_pyt_blk_nodes.insert(
         in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
     finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
@@ -582,14 +378,17 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info) {
 PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
-    const PartitionInfo& partition_info) {
+    const PartitionInfo& partition_info,
+    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info);
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+
+  // It's possible that some TensorRT blocks have nonTensor inputs/outputs because they are interleaved with Torch blocks
   // resolve nonTensor inputs/outputs
-  resolveNonTensorInputs(segmented_blocks);
+  resolveTRTNonTensorInputs(segmented_blocks);

   // register input/output torch::jit::Value for segmented graphs
   LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index 31ecfebf25..fce88134b7 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -16,12 +16,16 @@ namespace partitioning {

 typedef std::vector<SegmentedBlock> PartitionedGraph;

-PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& partition_info);
+PartitionedGraph segment_graph(
+    torch::jit::Block* block,
+    const PartitionInfo& partition_info,
+    std::unordered_map<torch::jit::Node*, int>& fallback_nodes);

 PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
-    const PartitionInfo& partition_info);
+    const PartitionInfo& partition_info,
+    std::unordered_map<torch::jit::Node*, int>& fallback_nodes);

 std::ostream& operator<<(std::ostream& os, const PartitionedGraph& g);

diff --git a/tests/core/partitioning/BUILD b/tests/core/partitioning/BUILD
index 98c549e11d..ec5e9c77fc 100644
--- a/tests/core/partitioning/BUILD
+++ b/tests/core/partitioning/BUILD
@@ -13,7 +13,8 @@ filegroup(
             "//tests/modules:mobilenet_v2_traced.jit.pt",
             "//tests/modules:conditional_scripted.jit.pt",
             "//tests/modules:loop_fallback_eval_scripted.jit.pt",
-            "//tests/modules:loop_fallback_no_eval_scripted.jit.pt"]
+            "//tests/modules:loop_fallback_no_eval_scripted.jit.pt",
+            "//tests/modules:inplace_op_if_scripted.jit.pt"]
 )

 partitioning_test(
diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp
index 9698559f80..e2cbbb549b 100644
--- a/tests/core/partitioning/test_conditionals.cpp
+++ b/tests/core/partitioning/test_conditionals.cpp
@@ -42,3 +42,34 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {

   ASSERT_TRUE(conditional_engines_count == 2);
 }
+
+TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/inplace_op_if_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return;
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{4, 4}, {4, 4}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  for (auto in_shape : input_shapes) {
+    auto in = at::randint(5, in_shape, {at::kCUDA});
+    jit_inputs_ivalues.push_back(in.clone());
+    trt_inputs_ivalues.push_back(in.clone());
+  }
+
+  std::vector<torch_tensorrt::core::ir::Input> inputs{torch_tensorrt::core::ir::Input({4, 4}),
+                                                      torch_tensorrt::core::ir::Input({4, 4})};
+  auto g = mod.get_method("forward").graph();
+  torch_tensorrt::core::CompileSpec cfg(inputs);
+  cfg.partition_info.enabled = true;
+  cfg.partition_info.forced_fallback_operators.push_back("prim::ListConstruct");
+
+  auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
+  auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results, trt_results, 2e-6));
+}
diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
index facdd31151..fea202fc65 100644
--- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
+++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
@@ -123,8 +123,9 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
     input_types.insert({g->inputs()[i], {at::kFloat}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);

   int torch_block_cnt = 0, trt_block_cnt = 0;
   for (const auto& segmented_block : segmented_blocks) {
@@ -181,8 +182,9 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
     input_types.insert({g->inputs()[i], {at::kFloat}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);

   int torch_block_cnt = 0, trt_block_cnt = 0;
   for (const auto& segmented_block : segmented_blocks) {
@@ -255,7 +257,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
   torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
   auto fallback_g = new_mod.get_method("forward").graph();
   int count = count_trt_engines(fallback_g);
-  ASSERT_TRUE(count == 2);
+  ASSERT_TRUE(count == 1);
 }

 TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
@@ -372,7 +374,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
     input_types.insert({g->inputs()[i], {at::kFloat}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+  auto segmented_blocks =
+      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);

   int torch_block_cnt = 0, trt_block_cnt = 0;
   for (const auto& segmented_block : segmented_blocks) {
diff --git a/tests/core/partitioning/test_segmentation.cpp b/tests/core/partitioning/test_segmentation.cpp
index 1a77833577..bf32bcf918 100644
--- a/tests/core/partitioning/test_segmentation.cpp
+++ b/tests/core/partitioning/test_segmentation.cpp
@@ -74,8 +74,9 @@ TEST(Partitioning, SegmentSequentialModelCorrectly) {

   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info);
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
   ASSERT_TRUE(
       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2));
   ASSERT_TRUE(
@@ -109,8 +110,9 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
   partition_info.min_block_size = 3;
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info);
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
   ASSERT_TRUE(
       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
   ASSERT_TRUE(
@@ -144,8 +146,9 @@ TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
   partition_info.forced_fallback_operators.push_back("aten::relu");
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info);
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
   ASSERT_TRUE(
       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3));
   ASSERT_TRUE(
@@ -179,8 +182,9 @@ TEST(Partitioning, SegmentBranchModelCorrectly) {

   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info);
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
   ASSERT_TRUE(
       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 2));
   ASSERT_TRUE(
@@ -215,8 +219,9 @@ TEST(Partitioning, SegmentBranchModelWithMinBlockSizeCorrectly) {
   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
   partition_info.min_block_size = 3;
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info);
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
   ASSERT_TRUE(
       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
   ASSERT_TRUE(
@@ -255,8 +260,9 @@ TEST(Partitioning, SegmentBranchModelWithForcedFallbackOPCorrectly) {
   torch_tensorrt::core::partitioning::PartitionInfo partition_info;
   partition_info.enabled = true;
   partition_info.forced_fallback_operators.push_back("aten::relu");
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   torch_tensorrt::core::partitioning::PartitionedGraph segmented_blocks =
-      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info);
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
   ASSERT_TRUE(
       checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 3));
   ASSERT_TRUE(
diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp
index 8effa821ae..7bcabc0d51 100644
--- a/tests/core/partitioning/test_shape_analysis.cpp
+++ b/tests/core/partitioning/test_shape_analysis.cpp
@@ -66,8 +66,9 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
     input_types.insert({g->inputs()[i], {at::kFloat}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);

   ASSERT_TRUE(checkSegmentedBlockInputShape(
       segmented_blocks,
@@ -116,8 +117,9 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
     input_types.insert({g->inputs()[i], {at::kFloat}});
   }
   auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
-      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info, fallback_nodes);

   ASSERT_TRUE(checkSegmentedBlockInputShape(
       segmented_blocks,
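All of the test updates above follow one call pattern: declare the map, pass it through, and let partitioning fill it in. A condensed sketch of that pattern (the helper name partition_with_diagnostics is hypothetical; the Partition signature is the one declared in partitioning.h above):

#include "core/partitioning/partitioning.h"

namespace p = torch_tensorrt::core::partitioning;

p::PartitionedGraph partition_with_diagnostics(
    torch::jit::Block* block,
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
    const p::PartitionInfo& partition_info) {
  // Callers now own the fallback_nodes map; after Partition() returns, it maps
  // every Torch-destined node to the reason code logged by check_node_fallback().
  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
  return p::Partition(block, example_tensor_map, partition_info, fallback_nodes);
}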
diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py
index 252a3a2b5d..20d501045f 100644
--- a/tests/modules/custom_models.py
+++ b/tests/modules/custom_models.py
@@ -87,6 +87,20 @@ def forward(self, x):
         return x


+# Sample Inplace OP in Conditional Block Model
+class FallbackInplaceOPIf(nn.Module):
+
+    def __init__(self):
+        super(FallbackInplaceOPIf, self).__init__()
+
+    def forward(self, x, y):
+        mod_list = [x]
+        if x.sum() > y.sum():
+            mod_list.append(y)
+        z = torch.cat(mod_list)
+        return z
+
+
 def BertModule():
     model_name = "bert-base-uncased"
     enc = BertTokenizer.from_pretrained(model_name)
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index 3ac2de5ac6..1702628f20 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -92,10 +92,18 @@
         "model": cm.LoopFallbackEval(),
         "path": "script"
     },
+    "loop_fallback_no_eval": {
+        "model": cm.LoopFallbackNoEval(),
+        "path": "script"
+    },
     "conditional": {
         "model": cm.FallbackIf(),
         "path": "script"
     },
+    "inplace_op_if": {
+        "model": cm.FallbackInplaceOPIf(),
+        "path": "script"
+    },
     "bert-base-uncased": {
         "model": cm.BertModule(),
         "path": "trace"