6
6
# mypy: disable-error-code=misc
7
7
8
8
"""Unit tests for the onnx_types module."""
9
+ from __future__ import annotations
9
10
10
11
import unittest
11
12
@@ -24,13 +25,13 @@ def test_instantiation(self):
24
25
FLOAT [...]()
25
26
26
27
@parameterized .expand (tensor_type_registry .items ())
27
- def test_type_properties (self , dtype : DType , tensor_type : TensorType ):
28
+ def test_type_properties (self , dtype : DType , tensor_type : type [ TensorType ] ):
28
29
self .assertEqual (tensor_type .dtype , dtype )
29
30
self .assertIsNone (tensor_type .shape )
30
- self .assertEqual (tensor_type [...].shape , ...)
31
- self .assertEqual (tensor_type [...].dtype , dtype )
32
- self .assertEqual (tensor_type [1 , 2 , 3 ].shape , (1 , 2 , 3 ))
33
- self .assertEqual (tensor_type [1 , 2 , 3 ].dtype , dtype )
31
+ self .assertEqual (tensor_type [...].shape , ...) # type: ignore[index]
32
+ self .assertEqual (tensor_type [...].dtype , dtype ) # type: ignore[index]
33
+ self .assertEqual (tensor_type [1 , 2 , 3 ].shape , (1 , 2 , 3 )) # type: ignore[index]
34
+ self .assertEqual (tensor_type [1 , 2 , 3 ].dtype , dtype ) # type: ignore[index]
34
35
35
36
@parameterized .expand ([(dtype ,) for dtype in tensor_type_registry ])
36
37
def test_dtype_bound_to_subclass (self , dtype : DType ):
0 commit comments