diff --git a/core/compiler.cpp b/core/compiler.cpp
index 92809affc8..3dd735a59e 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::Block* block,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
@@ -149,10 +150,12 @@ partitioning::GraphAndMapping BuildHybridGraph(
   // TODO: Combine this within partition call
   partitioning::populateInputIValues(&partitioning_ctx);
 
-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;
 
     for (auto& seg_block : segmented_blocks) {
       LOG_INFO("Block segment:" << seg_block);
@@ -160,6 +163,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
         auto inputs = seg_block.construct_inputs_spec();
         // update the input ranges for each segments
         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +184,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
             true);
 
         seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
+              TORCHTRT_THROW_ERROR(
+                  "Full compilation specified but node "
+                  << *torch_node
+                  << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                  << " Try recompiling with require_full_compilation=False.");
+            }
+          }
+        }
       }
     }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
+    }
   }
 
   return partitioning::stitch(&partitioning_ctx, block);
@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool requires_collection_handling = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
         spec[i].dtype = at::kFloat;
-      } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+      } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
           std::stringstream ss;
@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   return engine;
 }
 
+bool userRequestedFallback(CompileSpec& cfg) {
+  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
+      cfg.partitioning_info.forced_fallback_operators.size() != 0;
+}
+
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
@@ -315,8 +349,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+
+  // Determine whether user specifications necessitate partitioning
+  auto isFallbackRequested = userRequestedFallback(cfg);
+
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -330,20 +373,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     user_requested_long &= (casts_inserted > 0);
   }
 
-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled && !user_requested_long &&
-      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
-      !outputIsCollection) {
+  // Partitioning is required if:
+  // 1. User requested some modules/operators fallback
+  // 2. The block (graph) cannot be converted due to operator coverage
+  // 3. The output of the graph is a collection
+  // 4. The user requested a non-TRT data type input
+  auto isPartitioningRequired =
+      (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
+
+  // The user did not require full compilation, but the model can be fully compiled
+  if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partitioning_info.enabled &&
-      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection || user_requested_long)) {
-    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+  // The user did not require full compilation, and the model cannot be fully compiled
+  // or, the user required full compilation but the I/O of the graph uses collections
+  if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
+    // If the model is fully-compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index b1406446f1..931209d636 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -32,7 +32,7 @@ int AutocastLongInputs(
     std::string target_device_name) {
   int num_autocasts = 0;
   // For each graph input, determine if it can be autocasted
-  for (int i = 0; i < g->inputs().size(); i++) {
+  for (size_t i = 0; i < g->inputs().size(); i++) {
     auto input = g->inputs()[i];
 
     // Autocasted inputs must be Tensor-type
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 4d74461454..7b764b04fb 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -564,7 +564,21 @@ void populateInputIValues(PartitioningCtx* ctx) {
   }
 }
 
-void partition(PartitioningCtx* ctx) {
+void partition(PartitioningCtx* ctx, bool expect_full_compilation) {
+  // If full compilation is expected, overwrite minimum block size
+  // Any nonzero block size is valid if full compilation to TRT is desired
+  // Override the default min_block_size to ensure all TRT-supported operations are
+  // executed in TRT, regardless of the size of the graph
+  if (expect_full_compilation) {
+    // If minimum block size is different from the default, the user must have specified it
+    if (ctx->settings.min_block_size != 3) {
+      LOG_WARNING(
+          "Detected user-specified min_block_size with require_full_compilation=True, "
+          << "disregarding min_block_size.");
+    }
+    ctx->settings.min_block_size = 1;
+  }
+
   LOG_DEBUG(ctx->settings);
 
   // Go through all the blocks to do the partitioning
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index 7c72d091b6..3315ffa210 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -18,6 +18,19 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> Example
 typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
     GraphAndMapping;
 
+// Set of schemas allowed to be executed in Torch, even with require_full_compilation=true,
+// as necessary for returning collections of Tensors or other complex constructs, and for
+// processing inputs to TRT engines
+const std::unordered_set<c10::Symbol> CollectionNodeKinds = {
+    c10::Symbol::fromQualString("prim::Constant"),
+    c10::Symbol::fromQualString("aten::__getitem__"),
+    c10::Symbol::fromQualString("prim::ListConstruct"),
+    c10::Symbol::fromQualString("prim::ListUnpack"),
+    c10::Symbol::fromQualString("prim::TupleIndex"),
+    c10::Symbol::fromQualString("prim::TupleConstruct"),
+    c10::Symbol::fromQualString("prim::TupleUnpack"),
+};
+
 ExampleIValues generateRandomInputs(
     ir::CollectionInputSpecMap& input_ranges,
     ir::CollectionTypeMap& input_types,
@@ -35,7 +48,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
 
 GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
 
-void partition(PartitioningCtx* ctx);
+void partition(PartitioningCtx* ctx, bool expect_full_compilation = false);
 
 } // namespace partitioning
 } // namespace core
diff --git a/tests/py/api/test_e2e_behavior.py b/tests/py/api/test_e2e_behavior.py
index 385fe916f4..499106e9ca 100644
--- a/tests/py/api/test_e2e_behavior.py
+++ b/tests/py/api/test_e2e_behavior.py
@@ -4,6 +4,7 @@
 import torchvision.models as models
 import copy
 from typing import Dict
+from utils import same_output_format
 
 
 class TestInputTypeDefaultsFP32Model(unittest.TestCase):
@@ -109,6 +110,73 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
         )
         trt_mod(self.input)
 
+    def test_nested_combination_tuple_list_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y, z):
+                c = 1.0
+                b = x + 2.0 * z
+                b = y + b
+                a = b + c
+                return (a, [b, c])
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_3 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_output = trt_mod(self.input_1, self.input_2, self.input_3)
+        torch_output = self.model(self.input_1, self.input_2, self.input_3)
+        assert same_output_format(
+            trt_output, torch_output
+        ), "Found differing output formatting between Torch-TRT and Torch"
+
+    def test_tuple_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y):
+                a = x + y
+                return (a,)
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_output = trt_mod(self.input_1, self.input_2)
+        torch_output = self.model(self.input_1, self.input_2)
+        assert same_output_format(
+            trt_output, torch_output
+        ), "Found differing output formatting between Torch-TRT and Torch"
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/py/api/utils.py b/tests/py/api/utils.py
index b1e6632ec3..ff6bc39158 100644
--- a/tests/py/api/utils.py
+++ b/tests/py/api/utils.py
@@ -13,3 +13,42 @@ def cosine_similarity(gt_tensor, pred_tensor):
     res = res.cpu().detach().item()
 
     return res
+
+
+def same_output_format(trt_output, torch_output):
+    # For each encountered collection type, ensure the torch and trt outputs agree
+    # on type and size, checking recursively through all member elements.
+    if isinstance(trt_output, tuple):
+        return (
+            isinstance(torch_output, tuple)
+            and (len(trt_output) == len(torch_output))
+            and all(
+                same_output_format(trt_entry, torch_entry)
+                for trt_entry, torch_entry in zip(trt_output, torch_output)
+            )
+        )
+    elif isinstance(trt_output, list):
+        return (
+            isinstance(torch_output, list)
+            and (len(trt_output) == len(torch_output))
+            and all(
+                same_output_format(trt_entry, torch_entry)
+                for trt_entry, torch_entry in zip(trt_output, torch_output)
+            )
+        )
+    elif isinstance(trt_output, dict):
+        return (
+            isinstance(torch_output, dict)
+            and (len(trt_output) == len(torch_output))
+            and (trt_output.keys() == torch_output.keys())
+            and all(
+                same_output_format(trt_output[key], torch_output[key])
+                for key in trt_output.keys()
+            )
+        )
+    elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
+        raise AssertionError(
+            "Unsupported output type 'set' encountered in output format check."
+        )
+    else:
+        return type(trt_output) is type(torch_output)
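Note on the new test utility: same_output_format (tests/py/api/utils.py above) only verifies that the Torch and Torch-TRT outputs share the same collection structure (matching tuple/list/dict types, lengths, and keys); it does not compare values. A minimal usage sketch combining it with the existing cosine_similarity helper from the same utils module is shown below; the TupleModel module, the inputs, and the 0.99 threshold are illustrative assumptions, not part of this patch.

    import torch
    import torch_tensorrt as torchtrt

    from utils import cosine_similarity, same_output_format


    class TupleModel(torch.nn.Module):
        def forward(self, x, y):
            a = x + y
            # Return a nested collection so that collection handling is exercised
            return (a, [a * 2.0, a - y])


    model = TupleModel().eval().to("cuda")
    inputs = [torch.ones((5, 5), device="cuda") for _ in range(2)]

    trt_mod = torchtrt.ts.compile(
        torch.jit.script(model),
        inputs=[torchtrt.Input((5, 5), dtype=torch.float) for _ in range(2)],
        require_full_compilation=True,
        enabled_precisions={torch.float},
    )

    trt_out = trt_mod(*inputs)
    torch_out = model(*inputs)

    # Structure check: tuple(Tensor, list[Tensor, Tensor]) on both sides
    assert same_output_format(trt_out, torch_out)

    # Value check on the leading tensor of each output (threshold is an assumption)
    assert cosine_similarity(torch_out[0], trt_out[0]) > 0.99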