diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f4618f6b3..146fb4aa2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -54,8 +54,10 @@ jobs: pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}" - name: Install dependencies run: | + + pip3 install requests mypy==0.990 graphviz numpy pip3 install -r requirements.txt - pip3 install mypy==0.960 numpy types-requests + - name: Build TorchData run: | pip3 install . @@ -66,7 +68,7 @@ jobs: run: | set -eux STATUS= - if ! mypy --config=mypy.ini; then + if ! mypy --config=mypy.ini --enable-recursive-aliases --install-types --non-interactive; then STATUS=fail fi if [ -n "$STATUS" ]; then diff --git a/torchdata/datapipes/iter/util/randomsplitter.py b/torchdata/datapipes/iter/util/randomsplitter.py index 2608f93c8..c2cfd3027 100644 --- a/torchdata/datapipes/iter/util/randomsplitter.py +++ b/torchdata/datapipes/iter/util/randomsplitter.py @@ -101,7 +101,7 @@ def __init__( self._rng = random.Random(self._seed) self._lengths: List[int] = [] - def draw(self) -> T: + def draw(self) -> T: # type: ignore selected_key = self._rng.choices(self.keys, self.weights)[0] index = self.key_to_index[selected_key] self.weights[index] -= 1 diff --git a/torchdata/datapipes/iter/util/tfrecordloader.py b/torchdata/datapipes/iter/util/tfrecordloader.py index 8ad5ea101..914ab5ab6 100644 --- a/torchdata/datapipes/iter/util/tfrecordloader.py +++ b/torchdata/datapipes/iter/util/tfrecordloader.py @@ -45,7 +45,7 @@ def prod(xs): # type: ignore[no-redef] # Note, reccursive types not supported by mypy at the moment # TODO(640): uncomment as soon as it becomes supported # https://github.com/python/mypy/issues/731 -# BinaryData = Union[str, List['BinaryData']] +BinaryData = Union[str, List["BinaryData"]] TFRecordBinaryData = Union[str, List[str], List[List[str]], List[List[List[Any]]]] TFRecordExampleFeature = Union[torch.Tensor, List[torch.Tensor], TFRecordBinaryData] TFRecordExample = Dict[str, TFRecordExampleFeature]