From ae8171cafc35e25ec0050b8bf295374e9388a7d8 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 31 Aug 2020 10:53:04 +0100 Subject: [PATCH 1/4] Count allele functions should return datasets not arrays. --- sgkit/stats/aggregation.py | 50 +++++++++++++++++++-------------- sgkit/tests/test_aggregation.py | 36 ++++++++++++++++-------- 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 15a1a16cf..388c3f309 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -1,8 +1,7 @@ import dask.array as da import numpy as np -import xarray as xr from numba import guvectorize -from xarray import DataArray, Dataset +from xarray import Dataset from ..typing import ArrayLike @@ -45,7 +44,7 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: out[a] += 1 -def count_call_alleles(ds: Dataset) -> DataArray: +def count_call_alleles(ds: Dataset) -> Dataset: """Compute per sample allele counts from genotype calls. Parameters @@ -56,10 +55,10 @@ def count_call_alleles(ds: Dataset) -> DataArray: Returns ------- - call_allele_count : DataArray - Allele counts with shape (variants, samples, alleles) and values - corresponding to the number of non-missing occurrences - of each allele. + Dataset + Array `call_allele_count` of allele counts with + shape (variants, samples, alleles) and values corresponding to + the number of non-missing occurrences of each allele. Examples -------- @@ -75,7 +74,7 @@ def count_call_alleles(ds: Dataset) -> DataArray: 2 0/1 1/0 3 0/0 0/0 - >>> sg.count_call_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE + >>> sg.count_call_alleles(ds)["call_allele_count"].values # doctest: +NORMALIZE_WHITESPACE array([[[1, 1], [1, 1]], @@ -92,14 +91,19 @@ def count_call_alleles(ds: Dataset) -> DataArray: G = da.asarray(ds["call_genotype"]) shape = (G.chunks[0], G.chunks[1], n_alleles) N = da.empty(n_alleles, dtype=np.uint8) - return xr.DataArray( - da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2), - dims=("variants", "samples", "alleles"), - name="call_allele_count", + return Dataset( + { + "call_allele_count": ( + ("variants", "samples", "alleles"), + da.map_blocks( + count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2 + ), + ) + } ) -def count_variant_alleles(ds: Dataset) -> DataArray: +def count_variant_alleles(ds: Dataset) -> Dataset: """Compute allele count from genotype calls. Parameters @@ -110,10 +114,10 @@ def count_variant_alleles(ds: Dataset) -> DataArray: Returns ------- - variant_allele_count : DataArray - Allele counts with shape (variants, alleles) and values - corresponding to the number of non-missing occurrences - of each allele. + Dataset + Array `variant_allele_count` of allele counts with + shape (variants, alleles) and values corresponding to + the number of non-missing occurrences of each allele. Examples -------- @@ -129,13 +133,17 @@ def count_variant_alleles(ds: Dataset) -> DataArray: 2 0/1 1/0 3 0/0 0/0 - >>> sg.count_variant_alleles(ds).values # doctest: +NORMALIZE_WHITESPACE + >>> sg.count_variant_alleles(ds)["variant_allele_count"].values # doctest: +NORMALIZE_WHITESPACE array([[2, 2], [1, 3], [2, 2], [4, 0]], dtype=uint64) """ - return xr.DataArray( - count_call_alleles(ds).sum(dim="samples").rename("variant_allele_count"), - dims=("variants", "alleles"), + return Dataset( + { + "variant_allele_count": ( + ("variants", "alleles"), + count_call_alleles(ds)["call_allele_count"].sum(dim="samples"), + ) + } ) diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index e1e2ad5a5..1cdebb9bb 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -21,22 +21,25 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset: def test_count_variant_alleles__single_variant_single_sample(): - ac = count_variant_alleles(get_dataset([[[1, 0]]])) + ds = count_variant_alleles(get_dataset([[[1, 0]]])) + ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[1, 1]])) def test_count_variant_alleles__multi_variant_single_sample(): - ac = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) + ds = count_variant_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) + ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[2, 0], [1, 1], [1, 1], [0, 2]])) def test_count_variant_alleles__single_variant_multi_sample(): - ac = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) + ds = count_variant_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) + ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[4, 4]])) def test_count_variant_alleles__multi_variant_multi_sample(): - ac = count_variant_alleles( + ds = count_variant_alleles( get_dataset( [ [[0, 0], [0, 0], [0, 0]], @@ -46,11 +49,12 @@ def test_count_variant_alleles__multi_variant_multi_sample(): ] ) ) + ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[6, 0], [5, 1], [2, 4], [0, 6]])) def test_count_variant_alleles__missing_data(): - ac = count_variant_alleles( + ds = count_variant_alleles( get_dataset( [ [[-1, -1], [-1, -1], [-1, -1]], @@ -60,11 +64,12 @@ def test_count_variant_alleles__missing_data(): ] ) ) + ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[0, 0], [2, 1], [1, 2], [0, 6]])) def test_count_variant_alleles__higher_ploidy(): - ac = count_variant_alleles( + ds = count_variant_alleles( get_dataset( [ [[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]], @@ -74,6 +79,7 @@ def test_count_variant_alleles__higher_ploidy(): n_ploidy=3, ) ) + ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[1, 1, 1, 0], [1, 2, 2, 1]])) @@ -89,22 +95,25 @@ def test_count_variant_alleles__chunked(): def test_count_call_alleles__single_variant_single_sample(): - ac = count_call_alleles(get_dataset([[[1, 0]]])) + ds = count_call_alleles(get_dataset([[[1, 0]]])) + ac = ds["call_allele_count"] np.testing.assert_equal(ac, np.array([[[1, 1]]])) def test_count_call_alleles__multi_variant_single_sample(): - ac = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) + ds = count_call_alleles(get_dataset([[[0, 0]], [[0, 1]], [[1, 0]], [[1, 1]]])) + ac = ds["call_allele_count"] np.testing.assert_equal(ac, np.array([[[2, 0]], [[1, 1]], [[1, 1]], [[0, 2]]])) def test_count_call_alleles__single_variant_multi_sample(): - ac = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) + ds = count_call_alleles(get_dataset([[[0, 0], [1, 0], [0, 1], [1, 1]]])) + ac = ds["call_allele_count"] np.testing.assert_equal(ac, np.array([[[2, 0], [1, 1], [1, 1], [0, 2]]])) def test_count_call_alleles__multi_variant_multi_sample(): - ac = count_call_alleles( + ds = count_call_alleles( get_dataset( [ [[0, 0], [0, 0], [0, 0]], @@ -114,6 +123,7 @@ def test_count_call_alleles__multi_variant_multi_sample(): ] ) ) + ac = ds["call_allele_count"] np.testing.assert_equal( ac, np.array( @@ -128,7 +138,7 @@ def test_count_call_alleles__multi_variant_multi_sample(): def test_count_call_alleles__missing_data(): - ac = count_call_alleles( + ds = count_call_alleles( get_dataset( [ [[-1, -1], [-1, -1], [-1, -1]], @@ -138,6 +148,7 @@ def test_count_call_alleles__missing_data(): ] ) ) + ac = ds["call_allele_count"] np.testing.assert_equal( ac, np.array( @@ -152,7 +163,7 @@ def test_count_call_alleles__missing_data(): def test_count_call_alleles__higher_ploidy(): - ac = count_call_alleles( + ds = count_call_alleles( get_dataset( [ [[-1, -1, 0], [-1, -1, 1], [-1, -1, 2]], @@ -162,6 +173,7 @@ def test_count_call_alleles__higher_ploidy(): n_ploidy=3, ) ) + ac = ds["call_allele_count"] np.testing.assert_equal( ac, np.array( From fdd7b620b6f8ab9eef52756ab9bc5a92fa0cf519 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 31 Aug 2020 11:19:36 +0100 Subject: [PATCH 2/4] Add merge=True to count allele functions. --- sgkit/stats/aggregation.py | 16 ++++++++++++---- sgkit/tests/test_aggregation.py | 8 ++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 388c3f309..353ec6c49 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -44,7 +44,7 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None: out[a] += 1 -def count_call_alleles(ds: Dataset) -> Dataset: +def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: """Compute per sample allele counts from genotype calls. Parameters @@ -52,6 +52,9 @@ def count_call_alleles(ds: Dataset) -> Dataset: ds : Dataset Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. + merge : bool + If True, merge the input dataset and the computed variables into + a single dataset, otherwise return only the computed variables. Returns ------- @@ -91,7 +94,7 @@ def count_call_alleles(ds: Dataset) -> Dataset: G = da.asarray(ds["call_genotype"]) shape = (G.chunks[0], G.chunks[1], n_alleles) N = da.empty(n_alleles, dtype=np.uint8) - return Dataset( + new_ds = Dataset( { "call_allele_count": ( ("variants", "samples", "alleles"), @@ -101,9 +104,10 @@ def count_call_alleles(ds: Dataset) -> Dataset: ) } ) + return ds.merge(new_ds) if merge else new_ds -def count_variant_alleles(ds: Dataset) -> Dataset: +def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: """Compute allele count from genotype calls. Parameters @@ -111,6 +115,9 @@ def count_variant_alleles(ds: Dataset) -> Dataset: ds : Dataset Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. + merge : bool + If True, merge the input dataset and the computed variables into + a single dataset, otherwise return only the computed variables. Returns ------- @@ -139,7 +146,7 @@ def count_variant_alleles(ds: Dataset) -> Dataset: [2, 2], [4, 0]], dtype=uint64) """ - return Dataset( + new_ds = Dataset( { "variant_allele_count": ( ("variants", "alleles"), @@ -147,3 +154,4 @@ def count_variant_alleles(ds: Dataset) -> Dataset: ) } ) + return ds.merge(new_ds) if merge else new_ds diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py index 1cdebb9bb..291b81aca 100644 --- a/sgkit/tests/test_aggregation.py +++ b/sgkit/tests/test_aggregation.py @@ -22,6 +22,7 @@ def get_dataset(calls: ArrayLike, **kwargs: Any) -> Dataset: def test_count_variant_alleles__single_variant_single_sample(): ds = count_variant_alleles(get_dataset([[[1, 0]]])) + assert "call_genotype" in ds ac = ds["variant_allele_count"] np.testing.assert_equal(ac, np.array([[1, 1]])) @@ -94,6 +95,13 @@ def test_count_variant_alleles__chunked(): xr.testing.assert_equal(ac1, ac2) # type: ignore[no-untyped-call] +def test_count_variant_alleles__no_merge(): + ds = count_variant_alleles(get_dataset([[[1, 0]]]), merge=False) + assert "call_genotype" not in ds + ac = ds["variant_allele_count"] + np.testing.assert_equal(ac, np.array([[1, 1]])) + + def test_count_call_alleles__single_variant_single_sample(): ds = count_call_alleles(get_dataset([[[1, 0]]])) ac = ds["call_allele_count"] From be502ec06e8bff29251ddda163f70a4f8114e88c Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 1 Sep 2020 11:58:02 +0100 Subject: [PATCH 3/4] Issue a MergeWarning in the case when input variables are overwritten. --- setup.cfg | 2 ++ sgkit/stats/aggregation.py | 25 ++++++++++++++++--------- sgkit/tests/test_utils.py | 28 +++++++++++++++++++++++++++- sgkit/utils.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 10 deletions(-) diff --git a/setup.cfg b/setup.cfg index c52259849..af85ef741 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,8 @@ fail_under = 100 [tool:pytest] addopts = --doctest-modules --ignore=validation norecursedirs = .eggs docs +filterwarnings = + error [flake8] ignore = diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 353ec6c49..bd08fa656 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -3,7 +3,8 @@ from numba import guvectorize from xarray import Dataset -from ..typing import ArrayLike +from sgkit.typing import ArrayLike +from sgkit.utils import merge_datasets @guvectorize( # type: ignore @@ -52,9 +53,12 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: ds : Dataset Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. - merge : bool - If True, merge the input dataset and the computed variables into - a single dataset, otherwise return only the computed variables. + merge : bool, optional + If True (the default), merge the input dataset and the computed + output variables into a single dataset. Output variables will + overwrite any input variables with the same name, and a warning + will be issued in this case. + If False, return only the computed output variables. Returns ------- @@ -104,7 +108,7 @@ def count_call_alleles(ds: Dataset, merge: bool = True) -> Dataset: ) } ) - return ds.merge(new_ds) if merge else new_ds + return merge_datasets(ds, new_ds) if merge else new_ds def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: @@ -115,9 +119,12 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: ds : Dataset Genotype call dataset such as from `sgkit.create_genotype_call_dataset`. - merge : bool - If True, merge the input dataset and the computed variables into - a single dataset, otherwise return only the computed variables. + merge : bool, optional + If True (the default), merge the input dataset and the computed + output variables into a single dataset. Output variables will + overwrite any input variables with the same name, and a warning + will be issued in this case. + If False, return only the computed output variables. Returns ------- @@ -154,4 +161,4 @@ def count_variant_alleles(ds: Dataset, merge: bool = True) -> Dataset: ) } ) - return ds.merge(new_ds) if merge else new_ds + return merge_datasets(ds, new_ds) if merge else new_ds diff --git a/sgkit/tests/test_utils.py b/sgkit/tests/test_utils.py index a73c79684..c2d2db124 100644 --- a/sgkit/tests/test_utils.py +++ b/sgkit/tests/test_utils.py @@ -1,11 +1,19 @@ from typing import Any, List +import dask.array as da import numpy as np import pytest +import xarray as xr from hypothesis import given, settings from hypothesis import strategies as st -from sgkit.utils import check_array_like, encode_array, split_array_chunks +from sgkit.utils import ( + MergeWarning, + check_array_like, + encode_array, + merge_datasets, + split_array_chunks, +) def test_check_array_like(): @@ -66,6 +74,24 @@ def test_encode_array( np.testing.assert_equal(n, expected_names) +def test_merge_datasets(): + ds = xr.Dataset(dict(x=xr.DataArray(da.zeros(100)))) + + new_ds1 = xr.Dataset( + dict(y=xr.DataArray(da.zeros(100)), z=xr.DataArray(da.zeros(100))) + ) + new_ds2 = xr.Dataset( + dict(y=xr.DataArray(da.ones(100)), z=xr.DataArray(da.zeros(100))) + ) + + ds = merge_datasets(ds, new_ds1) + assert "y" in ds + + with pytest.warns(MergeWarning): + ds = merge_datasets(ds, new_ds2) + np.testing.assert_equal(ds["y"].values, np.ones(100)) + + @pytest.mark.parametrize( "n,blocks,expected_chunks", [ diff --git a/sgkit/utils.py b/sgkit/utils.py index 81e846a18..0a59d1cae 100644 --- a/sgkit/utils.py +++ b/sgkit/utils.py @@ -1,6 +1,8 @@ +import warnings from typing import Any, List, Set, Tuple, Union import numpy as np +from xarray import Dataset from .typing import ArrayLike, DType @@ -100,6 +102,39 @@ def encode_array(x: ArrayLike) -> Tuple[ArrayLike, List[Any]]: return rank[inverse], names[index] +class MergeWarning(UserWarning): + """Warnings about merging datasets.""" + + pass + + +def merge_datasets(input: Dataset, output: Dataset) -> Dataset: + """Merge the input and output datasets into a new dataset, giving precedence to variables in the output. + + Parameters + ---------- + input : Dataset + The input dataset. + output : Dataset + The output dataset. + + Returns + ------- + Dataset + The merged dataset. If `input` and `output` have variables with the same name, a `MergeWarning` + is issued, and the variables from the `output` dataset are used. + """ + input_vars = {str(v) for v in input.data_vars.keys()} + output_vars = {str(v) for v in output.data_vars.keys()} + clobber_vars = sorted(list(input_vars & output_vars)) + if len(clobber_vars) > 0: + warnings.warn( + f"The following variables in the input dataset will be replaced in the output: {', '.join(clobber_vars)}", + MergeWarning, + ) + return output.merge(input, compat="override") + + def split_array_chunks(n: int, blocks: int) -> Tuple[int, ...]: """Compute chunk sizes for an array split into blocks. From 0e4cefef919c7a2c9bdf949d05538ffe3dd1b685 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 1 Sep 2020 12:35:36 +0100 Subject: [PATCH 4/4] Removed unused variable in test --- sgkit/tests/test_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sgkit/tests/test_utils.py b/sgkit/tests/test_utils.py index c2d2db124..02c54cf38 100644 --- a/sgkit/tests/test_utils.py +++ b/sgkit/tests/test_utils.py @@ -77,12 +77,8 @@ def test_encode_array( def test_merge_datasets(): ds = xr.Dataset(dict(x=xr.DataArray(da.zeros(100)))) - new_ds1 = xr.Dataset( - dict(y=xr.DataArray(da.zeros(100)), z=xr.DataArray(da.zeros(100))) - ) - new_ds2 = xr.Dataset( - dict(y=xr.DataArray(da.ones(100)), z=xr.DataArray(da.zeros(100))) - ) + new_ds1 = xr.Dataset(dict(y=xr.DataArray(da.zeros(100)))) + new_ds2 = xr.Dataset(dict(y=xr.DataArray(da.ones(100)))) ds = merge_datasets(ds, new_ds1) assert "y" in ds