Use numpy.typing.DTypeLike #594


Closed · wants to merge 4 commits

Changes from all commits
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -31,6 +31,7 @@ jobs:
       - name: Run pre-commit
         uses: pre-commit/[email protected]
       - name: Check for Sphinx doc warnings
+        if: contains(matrix.python-version, '3.8')
Collaborator

👍

Collaborator (Author)
@ravwojdyla thanks for your (offline) suggestion to do this. The workflow syntax works perfectly. But the build is failing on 3.8 too, so I need to look into that!

         run: |
           cd docs
           make html SPHINXOPTS="-W --keep-going -n"
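The if: condition added above restricts the Sphinx check to the 3.8 leg of the build matrix. A minimal sketch of how such a step sits inside a matrix job (the job name, matrix values, and checkout step here are illustrative, not taken from this PR):

    jobs:
      build:
        runs-on: ubuntu-latest
        strategy:
          matrix:
            python-version: ["3.7", "3.8"]
        steps:
          - uses: actions/checkout@v2
          - name: Check for Sphinx doc warnings
            # runs only on the matrix leg whose version string contains '3.8'
            if: contains(matrix.python-version, '3.8')
            run: |
              cd docs
              make html SPHINXOPTS="-W --keep-going -n"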
6 changes: 5 additions & 1 deletion docs/conf.py
@@ -94,7 +94,11 @@ def filter(self, record: pylogging.LogRecord) -> bool:

 autosummary_generate = True

-nitpick_ignore = [("py:class", "sgkit.display.GenotypeDisplay")]
+nitpick_ignore = [
+    ("py:class", "sgkit.display.GenotypeDisplay"),
+    ("py:class", "numpy.typing._dtype_like._DTypeDict"),
+    ("py:class", "numpy.typing._dtype_like._SupportsDType"),
+]


 # FIXME: Workaround for linking xarray module
3 changes: 3 additions & 0 deletions setup.cfg
@@ -144,6 +144,9 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-sgkit.*]
 allow_redefinition = True
+[mypy-sgkit.io.bgen.*]
+# avoid warning on unused ignore for Python 3.8, but not unused for 3.7
+warn_unused_ignores = False
 [mypy-sgkit.*.tests.*]
 disallow_untyped_defs = False
 disallow_untyped_decorators = False
22 changes: 15 additions & 7 deletions sgkit/io/bgen/bgen_reader.py
@@ -23,12 +23,20 @@
 import xarray as xr
 import zarr
 from cbgen import bgen_file, bgen_metafile
+from numpy.typing import DTypeLike
 from rechunker import api as rechunker_api
 from xarray import Dataset

 from sgkit import create_genotype_dosage_dataset
 from sgkit.io.utils import dataframe_to_dict, encode_contigs
-from sgkit.typing import ArrayLike, DType, PathType
+from sgkit.typing import ArrayLike, PathType
+
+try:
+    # needed to avoid Sphinx forward reference error for DTypeLike
+    # try block is needed since SupportsIndex is not in Python 3.7
+    from typing import SupportsIndex  # type: ignore # noqa: F401
+except ImportError:  # pragma: no cover
+    pass

 logger = logging.getLogger(__name__)

@@ -60,7 +68,7 @@ def __init__(
         self,
         path: PathType,
         metafile_path: Optional[PathType] = None,
-        dtype: DType = "float32",
+        dtype: DTypeLike = "float32",
     ) -> None:
         self.path = Path(path)
         self.metafile_path = (
@@ -202,8 +210,8 @@ def read_bgen(
     chunks: Union[str, int, Tuple[int, int, int]] = "auto",
     lock: bool = False,
     persist: bool = True,
-    contig_dtype: DType = "str",
-    gp_dtype: DType = "float32",
+    contig_dtype: DTypeLike = "str",
+    gp_dtype: DTypeLike = "float32",
 ) -> Dataset:
     """Read BGEN dataset.

@@ -394,7 +402,7 @@ def pack_variables(ds: Dataset) -> Dataset:
     return ds


-def unpack_variables(ds: Dataset, dtype: DType = "float32") -> Dataset:
+def unpack_variables(ds: Dataset, dtype: DTypeLike = "float32") -> Dataset:
     # Restore homozygous reference GP
     gp = ds["call_genotype_probability"].astype(dtype)
     if gp.sizes["genotypes"] != 2:
@@ -423,7 +431,7 @@ def rechunk_bgen(
     chunk_length: int = 10_000,
     chunk_width: int = 1_000,
     compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
-    probability_dtype: Optional[DType] = "uint8",
+    probability_dtype: Optional[DTypeLike] = "uint8",
     max_mem: str = "4GB",
     pack: bool = True,
     tempdir: Optional[PathType] = None,
@@ -533,7 +541,7 @@ def bgen_to_zarr(
     chunk_width: int = 1_000,
     temp_chunk_length: int = 100,
     compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
-    probability_dtype: Optional[DType] = "uint8",
+    probability_dtype: Optional[DTypeLike] = "uint8",
     max_mem: str = "4GB",
     pack: bool = True,
     tempdir: Optional[PathType] = None,
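For context, a minimal sketch of how the retyped read_bgen parameters are used; the file path is hypothetical, and since DTypeLike still accepts plain dtype strings, existing call sites keep working unchanged:

    import numpy as np

    from sgkit.io.bgen import read_bgen

    # dtype arguments may be strings, np.dtype objects, or scalar types,
    # all of which satisfy numpy.typing.DTypeLike
    ds = read_bgen(
        "example.bgen",       # hypothetical path
        gp_dtype=np.float32,  # scalar type instead of the "float32" string
        contig_dtype="str",   # plain string, as before
    )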
7 changes: 4 additions & 3 deletions sgkit/io/utils.py
@@ -6,13 +6,14 @@
 import numpy as np
 import xarray as xr
 import zarr
+from numpy.typing import DTypeLike

-from ..typing import ArrayLike, DType
+from ..typing import ArrayLike
 from ..utils import encode_array, max_str_len


 def dataframe_to_dict(
-    df: dd.DataFrame, dtype: Optional[Mapping[str, DType]] = None
+    df: dd.DataFrame, dtype: Optional[Mapping[str, DTypeLike]] = None
 ) -> Mapping[str, ArrayLike]:
     """ Convert dask dataframe to dictionary of arrays """
     arrs = {}
@@ -110,7 +111,7 @@ def zarrs_to_dataset(
 def concatenate_and_rechunk(
     zarrs: Sequence[zarr.Array],
     chunks: Optional[Tuple[int, ...]] = None,
-    dtype: DType = None,
+    dtype: DTypeLike = None,
 ) -> da.Array:
     """Perform a concatenate and rechunk operation on a collection of Zarr arrays
     to produce an array with a uniform chunking, suitable for saving as
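A minimal usage sketch for concatenate_and_rechunk, assuming concatenation happens along the first axis (the shapes and chunk sizes here are illustrative):

    import numpy as np
    import zarr

    from sgkit.io.utils import concatenate_and_rechunk

    # two Zarr arrays sharing their trailing dimension
    z1 = zarr.array(np.arange(20).reshape(10, 2))
    z2 = zarr.array(np.arange(20, 32).reshape(6, 2))

    # dtype is any DTypeLike value, e.g. a dtype string
    arr = concatenate_and_rechunk([z1, z2], chunks=(4, 2), dtype="int32")
    print(arr.shape)  # (16, 2)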
5 changes: 3 additions & 2 deletions sgkit/io/vcf/vcf_reader.py
@@ -20,13 +20,14 @@
 import numpy as np
 import xarray as xr
 from cyvcf2 import VCF, Variant
+from numpy.typing import DTypeLike

 from sgkit.io.utils import zarrs_to_dataset
 from sgkit.io.vcf import partition_into_regions
 from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
 from sgkit.io.vcfzarr_reader import vcf_number_to_dimension_and_size
 from sgkit.model import DIM_SAMPLE, DIM_VARIANT, create_genotype_call_dataset
-from sgkit.typing import ArrayLike, DType, PathType
+from sgkit.typing import ArrayLike, PathType
 from sgkit.utils import max_str_len

 DEFAULT_MAX_ALT_ALLELES = (
@@ -104,7 +105,7 @@ def _normalize_fields(vcf: VCF, fields: Sequence[str]) -> Sequence[str]:

 def _vcf_type_to_numpy_type_and_fill_value(
     vcf_type: str, category: str, key: str
-) -> Tuple[DType, Any]:
+) -> Tuple[DTypeLike, Any]:
     """Convert the VCF Type to a NumPy dtype and fill value."""
     if vcf_type == "Flag":
         return "bool", False
13 changes: 7 additions & 6 deletions sgkit/stats/ld.py
@@ -8,10 +8,11 @@
 import pandas as pd
 from dask.dataframe import DataFrame
 from numba import njit
+from numpy.typing import DTypeLike
 from xarray import Dataset

 from sgkit import variables
-from sgkit.typing import ArrayLike, DType
+from sgkit.typing import ArrayLike
 from sgkit.window import _get_chunked_windows, _sizes_to_start_offsets, has_windows


@@ -205,8 +206,8 @@ def _ld_matrix_jit(
     chunk_window_stops: ArrayLike,
     abs_chunk_start: int,
     chunk_max_window_start: int,
-    index_dtype: DType,
-    value_dtype: DType,
+    index_dtype: DTypeLike,
+    value_dtype: DTypeLike,
     threshold: float,
     scores: ArrayLike,
 ) -> List[Any]:  # pragma: no cover
@@ -246,7 +247,7 @@ def _ld_matrix_jit(

         if no_threshold or (res >= threshold and np.isfinite(res)):
             rows.append(
-                (index_dtype(index), index_dtype(other), value_dtype(res), cmp)
+                (index_dtype(index), index_dtype(other), value_dtype(res), cmp)  # type: ignore
             )

     return rows
@@ -258,8 +259,8 @@ def _ld_matrix(
     chunk_window_stops: ArrayLike,
     abs_chunk_start: int,
     chunk_max_window_start: int,
-    index_dtype: DType,
-    value_dtype: DType,
+    index_dtype: DTypeLike,
+    value_dtype: DTypeLike,
     threshold: float = np.nan,
     scores: Optional[ArrayLike] = None,
 ) -> ArrayLike:
5 changes: 3 additions & 2 deletions sgkit/stats/pca.py
@@ -4,14 +4,15 @@
 import numpy as np
 import xarray as xr
 from dask_ml.decomposition import TruncatedSVD
+from numpy.typing import DTypeLike
 from sklearn.base import BaseEstimator
 from sklearn.pipeline import Pipeline
 from typing_extensions import Literal
 from xarray import DataArray, Dataset

 from sgkit import variables

-from ..typing import ArrayLike, DType, RandomStateType
+from ..typing import ArrayLike, RandomStateType
 from ..utils import conditional_merge_datasets
 from .aggregation import count_call_alleles
 from .preprocessing import PattersonScaler
@@ -331,7 +332,7 @@ def _allele_counts(
     ds: Dataset,
     variable: str,
     check_missing: bool = True,
-    dtype: DType = "float32",
+    dtype: DTypeLike = "float32",
 ) -> DataArray:
     if variable not in ds:
         ds = count_call_alternate_alleles(ds)
5 changes: 3 additions & 2 deletions sgkit/tests/test_preprocessing.py
@@ -5,18 +5,19 @@
 import numpy as np
 import pytest
 import xarray as xr
+from numpy.typing import DTypeLike

 import sgkit.stats.preprocessing
 from sgkit import simulate_genotype_call_dataset
-from sgkit.typing import ArrayLike, DType
+from sgkit.typing import ArrayLike


 def simulate_alternate_allele_counts(
     n_variant: int,
     n_sample: int,
     ploidy: int,
     chunks: Any = (10, 10),
-    dtype: DType = "i",
+    dtype: DTypeLike = "i",
     seed: int = 0,
 ) -> ArrayLike:
     rs = da.random.RandomState(seed)
3 changes: 1 addition & 2 deletions sgkit/typing.py
@@ -1,10 +1,9 @@
 from pathlib import Path
-from typing import Any, Union
+from typing import Union

 import dask.array as da
 import numpy as np

 ArrayLike = Union[np.ndarray, da.Array]
-DType = Any
 PathType = Union[str, Path]
 RandomStateType = Union[np.random.RandomState, da.random.RandomState, int]
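This deletion is the heart of the PR: DType was just an alias for Any, so any value passed as a dtype type-checked. numpy.typing.DTypeLike (added in NumPy 1.20) is a genuine union of dtype-coercible types, which lets mypy reject non-dtype arguments. A minimal sketch of the difference (the function name is illustrative):

    import numpy as np
    from numpy.typing import DTypeLike


    def zeros_of(n: int, dtype: DTypeLike) -> np.ndarray:
        # DTypeLike covers dtype strings ("float32"), np.dtype objects,
        # scalar types such as np.int64, and None, among other forms
        return np.zeros(n, dtype=dtype)


    zeros_of(3, "float32")  # ok
    zeros_of(3, np.int64)   # ok
    zeros_of(3, 1.5)        # flagged by mypy; accepted under DType = Any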
5 changes: 3 additions & 2 deletions sgkit/utils.py
@@ -3,15 +3,16 @@

 import numpy as np
 from numba import guvectorize
+from numpy.typing import DTypeLike
 from xarray import Dataset

 from . import variables
-from .typing import ArrayLike, DType
+from .typing import ArrayLike


 def check_array_like(
     a: Any,
-    dtype: Union[None, DType, Set[DType]] = None,
+    dtype: Union[None, DTypeLike, Set[DTypeLike]] = None,
     kind: Union[None, str, Set[str]] = None,
     ndim: Union[None, int, Set[int]] = None,
 ) -> None:
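A minimal usage sketch, assuming check_array_like raises when a constraint is violated (the exact exception types are outside this hunk):

    import numpy as np

    from sgkit.utils import check_array_like

    a = np.zeros((10, 3), dtype="float32")

    check_array_like(a, dtype="float32", ndim=2)       # passes silently
    check_array_like(a, dtype={"float32", "float64"})  # a set of DTypeLike values
    check_array_like(a, ndim={2, 3})                   # ndim may also be a set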
7 changes: 4 additions & 3 deletions sgkit/window.py
@@ -2,12 +2,13 @@

 import dask.array as da
 import numpy as np
+from numpy.typing import DTypeLike
 from xarray import Dataset

 from sgkit.utils import conditional_merge_datasets, create_dataset
 from sgkit.variables import window_contig, window_start, window_stop

-from .typing import ArrayLike, DType
+from .typing import ArrayLike

 # Window definition (user code)

@@ -110,7 +111,7 @@ def moving_statistic(
     statistic: Callable[..., ArrayLike],
     size: int,
     step: int,
-    dtype: DType,
+    dtype: DTypeLike,
     **kwargs: Any,
 ) -> da.Array:
     """A Dask implementation of scikit-allel's moving_statistic function."""
@@ -135,7 +136,7 @@ def window_statistic(
     statistic: Callable[..., ArrayLike],
     window_starts: ArrayLike,
     window_stops: ArrayLike,
-    dtype: DType,
+    dtype: DTypeLike,
     chunks: Any = None,
     new_axis: Union[None, int, Iterable[int]] = None,
     **kwargs: Any,
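A minimal sketch of calling moving_statistic with the retyped dtype parameter, assuming its first argument is the array of values (that parameter sits above this hunk):

    import dask.array as da
    import numpy as np

    from sgkit.window import moving_statistic

    values = da.arange(100, chunks=10)

    # mean over windows of 10 elements, advancing 5 at a time;
    # dtype accepts any DTypeLike value, here a scalar type
    out = moving_statistic(values, np.mean, size=10, step=5, dtype=np.float64)
    print(out.compute())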