diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index 3de2daa14a..0fe56265e7 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -36,13 +36,15 @@ CompileSpec::CompileSpec(torch::jit::IValue input_signature) { graph_inputs.input_signature = input_signature; } -void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) { +void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue, int depth = 0) { + TORCHTRT_CHECK( + depth <= 2, "Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]") if (input_ivalue.isTuple()) { auto input_tuple = input_ivalue.toTuple(); std::vector converted_elements; for (auto item : input_tuple->elements()) { torch::jit::IValue converted_item; - to_internal_input_signature(item, converted_item); + to_internal_input_signature(item, converted_item, depth++); converted_elements.push_back(converted_item); auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements); converted_ivalue = torch::jit::IValue(tuple_ptr); @@ -53,7 +55,7 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV auto converted_elements = c10::impl::GenericList(type); for (auto item : input_list) { torch::jit::IValue converted_item; - to_internal_input_signature(item, converted_item); + to_internal_input_signature(item, converted_item, depth++); converted_elements.push_back(converted_item); } converted_ivalue = torch::jit::IValue(converted_elements); diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 50ba9b4ed5..0e11d3bcd3 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -194,17 +194,22 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback: return info -def _parse_input_signature(input_signature: Any): +def _parse_input_signature(input_signature: Any, depth: int = 0): + if depth > 2: + raise AssertionError( + "Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]" + ) + if isinstance(input_signature, tuple): input_list = [] for item in input_signature: - input = _parse_input_signature(item) + input = _parse_input_signature(item, depth + 1) input_list.append(input) return tuple(input_list) elif isinstance(input_signature, list): input_list = [] for item in input_signature: - input = _parse_input_signature(item) + input = _parse_input_signature(item, depth + 1) input_list.append(input) return input_list elif isinstance(input_signature, Input) or isinstance(