diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h
index 82c3f07801..ed448b1bbc 100644
--- a/core/lowering/lowering.h
+++ b/core/lowering/lowering.h
@@ -20,7 +20,7 @@ struct LowerInfo {
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
 
-  std::string getGPUDeviceString() {
+  std::string getGPUDeviceString() const {
     return "cuda:" + std::to_string(target_device.gpu_id);
   };
 };
diff --git a/core/partitioning/partitioninginfo/PartitioningInfo.h b/core/partitioning/partitioninginfo/PartitioningInfo.h
index 8eb052e0fa..ed7d2033c6 100644
--- a/core/partitioning/partitioninginfo/PartitioningInfo.h
+++ b/core/partitioning/partitioninginfo/PartitioningInfo.h
@@ -16,6 +16,11 @@ struct PartitioningInfo {
   uint64_t min_block_size = 1;
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
+  ir::Device target_device;
+
+  std::string getGPUDeviceString() const {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };
 
 std::ostream& operator<<(std::ostream& os, const PartitioningInfo& s);
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 81220e3af8..4220764dd6 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -99,7 +99,7 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
+torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
@@ -125,8 +125,11 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
+    auto cuda = g->insertConstant(device);
+    cuda->setType(torch::jit::DeviceObjType::get());
     auto none_val = g->insertNode(g->createNone())->output();
-    cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
+    cast_node =
+        g->create(torch::jit::aten::to, {cast_subgraph_value, cuda, const_type, const_zero, const_zero, none_val});
   }
   return cast_node;
 }
@@ -217,6 +220,8 @@ void getSegmentsOutputByRunning(
     ivalues_maps[output] = jit_results[idx++];
   }
 
+  auto target_device = partitioning_info.getGPUDeviceString();
+
   // auto int64 <=> int32 conversion
   if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
     // First, check if there is Int64 input
@@ -226,7 +231,7 @@ void getSegmentsOutputByRunning(
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
         if (t == at::kLong) {
           // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true);
+          auto cast_node = createCastNode(seg_block, i, true, target_device);
           seg_block.g()->prependNode(cast_node);
           seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
@@ -237,7 +242,7 @@ void getSegmentsOutputByRunning(
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
         if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false);
+          auto cast_node = createCastNode(seg_block, i, false, target_device);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp
index 40a7ae5c35..24aba31515 100644
--- a/cpp/src/compile_spec.cpp
+++ b/cpp/src/compile_spec.cpp
@@ -111,6 +111,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
+  internal.partitioning_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
 
   TORCHTRT_CHECK(
       !(external.require_full_compilation && (external.torch_executed_ops.size() > 0)),
@@ -132,11 +133,13 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
     case Device::DeviceType::kDLA:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kDLA;
       internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
+      internal.partitioning_info.target_device.device_type = nvinfer1::DeviceType::kDLA;
       break;
     case Device::DeviceType::kGPU:
     default:
       internal.convert_info.engine_settings.device.device_type = nvinfer1::DeviceType::kGPU;
       internal.lower_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
+      internal.partitioning_info.target_device.device_type = nvinfer1::DeviceType::kGPU;
   }
 
   switch (external.capability) {
@@ -155,6 +158,9 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.device.dla_core = external.device.dla_core;
   internal.lower_info.target_device.gpu_id = external.device.gpu_id;
   internal.lower_info.target_device.dla_core = external.device.dla_core;
+  internal.partitioning_info.target_device.gpu_id = external.device.gpu_id;
+  internal.partitioning_info.target_device.dla_core = external.device.dla_core;
+
   internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters;
   internal.convert_info.engine_settings.workspace_size = external.workspace_size;
   internal.convert_info.engine_settings.dla_sram_size = external.dla_sram_size;
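
Note for reviewers (not part of the patch): the sketch below is a minimal, standalone illustration of the node layout that createCastNode() now emits, i.e. a constant carrying the device string from getGPUDeviceString(), retyped to Device, feeding the device slot of the six-input aten::to.device overload. The helper name make_long_cast and the main() driver are invented for illustration only; just the graph-building calls mirror the patch.

// Illustrative only: builds a toy graph containing the same kind of
// device-pinned aten::to node that the patched createCastNode() inserts.
#include <torch/script.h>

#include <iostream>
#include <memory>
#include <string>

// Hypothetical helper mirroring the body of the patched createCastNode().
torch::jit::Node* make_long_cast(
    std::shared_ptr<torch::jit::Graph>& g,
    torch::jit::Value* in,
    const std::string& device) {
  auto const_type = g->insertConstant(4); // 4 == at::kLong in the ScalarType enum
  auto const_false = g->insertConstant(false); // non_blocking / copy flags
  auto cuda = g->insertConstant(device); // e.g. "cuda:1" from getGPUDeviceString()
  cuda->setType(c10::DeviceObjType::get()); // retype the string constant as a Device
  auto none_val = g->insertNode(g->createNone())->output(); // memory_format = None
  // aten::to.device(Tensor self, Device device, ScalarType dtype,
  //                 bool non_blocking, bool copy, MemoryFormat? memory_format)
  return g->create(
      c10::aten::to, {in, cuda, const_type, const_false, const_false, none_val});
}

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  auto in = g->addInput();
  in->setType(c10::TensorType::get());
  auto cast = make_long_cast(g, in, "cuda:0");
  g->appendNode(cast);
  g->registerOutput(cast->output());
  std::cout << *g << std::endl; // dump shows the device argument on the aten::to node
  return 0;
}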