🐛 [Bug] tests/py/dynamo/conversion/test_atan2_aten.py Conversion of function torch._ops.aten.aten::_assert_scalar not currently supported! #3581

Bug Description

a = (<dynamo.conversion.test_atan2_aten.TestAtan2Converter testMethod=test_dynamic_shape_atan2_2_3d_dim_dtype_int32>,), kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

../.venv/lib/python3.9/site-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
conversion/test_atan2_aten.py:182: in test_dynamic_shape_atan2
    self.run_test_with_dynamic_shape(
conversion/harness.py:598: in run_test_with_dynamic_shape
    super().run_test(mod, inputs_max, interp, rtol, atol, pyt_inputs=pyt_inputs)
conversion/harness.py:202: in run_test
    interpreter_result = interpreter.run()
../../../py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:719: in run
    self._construct_trt_network_def()
../../../py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:395: in _construct_trt_network_def
    super().run()
../.venv/lib/python3.9/site-packages/torch/fx/interpreter.py:171: in run
    self.env[node] = self.run_node(node)
../../../py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:779: in run_node
    trt_node: torch.fx.Node = super().run_node(n)
../.venv/lib/python3.9/site-packages/torch/fx/interpreter.py:240: in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torch_tensorrt.dynamo.conversion._TRTInterpreter.TRTInterpreter object at 0x7b0a4b7e76d0>, target = <OpOverload(op='aten._assert_scalar', overload='default')>, args = (<tensorrt_bindings.tensorrt.ITensor object at 0x7b0a4ba236f0>, "Runtime assertion failed for expression Eq(s90, s79) on node 'eq'"), kwargs = {}

    def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
        # TODO: Why is this stateful? We should be able to take in the inputs
        converter_packet = CONVERTERS.get(self._cur_node)
        if converter_packet is None:
>           raise UnsupportedOperatorException(
                f"Conversion of function {torch.typename(target)} not currently supported!"
            )
E           torch_tensorrt.dynamo.conversion._TRTInterpreter.UnsupportedOperatorException: Conversion of function torch._ops.aten.aten::_assert_scalar not currently supported!
E           
E           While executing %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s90, s79) on node 'eq'), kwargs = {})
E           GraphModule: class GraphModule(torch.nn.Module):
E               def forward(self, lhs_val, rhs_val):
E                   lhs_val: "i32[s35, s16, s79][s16*s79, s79, 1]"; rhs_val: "i32[s58, s16, s79][s16*s79, s79, 1]"; 
E               
E                   lhs_val, rhs_val, = fx_pytree.tree_flatten_spec(([lhs_val, rhs_val], {}), self._in_spec)
E                    # 
E                   sym_size_int_1: "Sym(s16)" = torch.ops.aten.sym_size.int(lhs_val, 1)
E                   sym_size_int_2: "Sym(s79)" = torch.ops.aten.sym_size.int(lhs_val, 2)
E                   sym_size_int_4: "Sym(s16)" = torch.ops.aten.sym_size.int(rhs_val, 1)
E                   sym_size_int_5: "Sym(s79)" = torch.ops.aten.sym_size.int(rhs_val, 2)
E                   
E                    # File: /home/naren/pytorch_org/tensorrt/tests/py/dynamo/conversion/test_atan2_aten.py:166 in forward, code: return torch.ops.aten.atan2.default(lhs_val, rhs_val)
E                   atan2: "f32[s35, s16, s79][s16*s79, s79, 1]" = torch.ops.aten.atan2.default(lhs_val, rhs_val);  lhs_val = rhs_val = None
E                   
E                    # 
E                   eq: "Sym(True)" = sym_size_int_2 == sym_size_int_5;  sym_size_int_2 = sym_size_int_5 = None
E                   _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s90, s79) on node 'eq'");  eq = _assert_scalar_default = None
E                   eq_1: "Sym(True)" = sym_size_int_1 == sym_size_int_4;  sym_size_int_1 = sym_size_int_4 = None
E                   _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s16, s43) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
E                   return pytree.tree_unflatten((atan2,), self._out_spec)
E                   
E           
E           Original traceback:
E           File "torch/fx/passes/runtime_assert.py", line 24, in insert_deferred_runtime_asserts
E           
E           To execute this test, run the following from the base repo dir:
E               python test_atan2_aten.py TestAtan2Converter.test_dynamic_shape_atan2_2_3d_dim_dtype_int32
E           
E           This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

../../../py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:874: UnsupportedOperatorException
---------------------------------------- Captured log call ----------------------------------------
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymInt'>
WARNING  torch_tensorrt.dynamo.conversion.converter_utils:converter_utils.py:72 Detected unparsable type in node formatting: <class 'torch.SymBool'>

To Reproduce

Steps to reproduce the behavior:

Run the failing test directly (repro command from the failure output above):

    python test_atan2_aten.py TestAtan2Converter.test_dynamic_shape_atan2_2_3d_dim_dtype_int32
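
A standalone sketch that should surface the same aten._assert_scalar nodes, assuming torch.export materializes deferred runtime asserts when symbolic dimensions are shared across inputs. The module and input signatures are reconstructed from the graph dump above; all other names are illustrative:

    import torch

    class Atan2(torch.nn.Module):
        def forward(self, lhs_val, rhs_val):
            return torch.ops.aten.atan2.default(lhs_val, rhs_val)

    # Dims 1 and 2 are shared across the two inputs, mirroring the
    # i32[s35, s16, s79] / i32[s58, s16, s79] signatures in the graph dump.
    d1 = torch.export.Dim("d1", min=2, max=8)
    d2 = torch.export.Dim("d2", min=2, max=8)
    ep = torch.export.export(
        Atan2(),
        (
            torch.randint(0, 10, (4, 4, 4), dtype=torch.int32),
            torch.randint(0, 10, (4, 4, 4), dtype=torch.int32),
        ),
        dynamic_shapes={
            "lhs_val": {0: torch.export.Dim("b1", min=2, max=8), 1: d1, 2: d2},
            "rhs_val": {0: torch.export.Dim("b2", min=2, max=8), 1: d1, 2: d2},
        },
    )
    # The deferred runtime asserts appear as
    # torch.ops.aten._assert_scalar.default calls in the printed code:
    print(ep.graph_module.code)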

Expected behavior

The graph should convert successfully: aten._assert_scalar nodes (the deferred runtime asserts shown above) should either be supported by the converter registry or dropped during lowering, so that test_dynamic_shape_atan2 compiles and passes.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages (see the sketch after this list).

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.9 (per the .venv/lib/python3.9 paths in the traceback)
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:
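
The captured log above shows loggers under the torch_tensorrt namespace, so the standard Python logging machinery can be used to surface the debug messages that carry build information:

    import logging

    # Raise the verbosity of all torch_tensorrt.* loggers
    # (e.g. torch_tensorrt.dynamo.conversion.converter_utils above):
    logging.basicConfig()
    logging.getLogger("torch_tensorrt").setLevel(logging.DEBUG)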

Additional context
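
The Original traceback above shows that these nodes come from torch.fx's insert_deferred_runtime_asserts pass, which materializes dynamic-shape guards as aten._assert_scalar calls. One possible workaround is to strip them from the ATen graph before conversion. A minimal torch.fx sketch, not Torch-TensorRT's actual lowering pipeline (the pass name and placement are assumptions):

    import torch
    from torch.fx import GraphModule

    def remove_assert_scalar(gm: GraphModule) -> GraphModule:
        # aten._assert_scalar is treated as side-effectful, so plain DCE
        # leaves it in place; erase the nodes explicitly, then let DCE
        # clean up the now-dead sym_size/eq producers.
        for node in list(gm.graph.nodes):
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten._assert_scalar.default
            ):
                gm.graph.erase_node(node)
        gm.graph.eliminate_dead_code()
        gm.recompile()
        return gm

Alternatively, a converter that emits no layers could be registered so the interpreter skips the op instead of raising. A hypothetical sketch, assuming the dynamo_tensorrt_converter decorator and the usual (ctx, target, args, kwargs, name) converter signature; check the converter registry for the exact API before relying on this:

    import torch
    from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter

    @dynamo_tensorrt_converter(torch.ops.aten._assert_scalar.default)
    def aten_ops_assert_scalar(ctx, target, args, kwargs, name):
        # Runtime shape assertion; the node has no users in the graph
        # (num_users=0 above), so emitting nothing is safe here (sketch only).
        return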
