Skip to content

Eliminate automatically computed intermediate variables (for popgen) #342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Methods
:toctree: generated/

count_call_alleles
count_cohort_alleles
count_variant_alleles
divergence
diversity
Expand Down
2 changes: 2 additions & 0 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ Xarray and Pandas operations in a single pipeline:
.pipe(lambda ds: ds.sel(variants=ds.variant_call_rate > .8))
# Assign a "cohort" variable that splits samples into two groups
.assign(sample_cohort=np.repeat([0, 1], ds.dims['samples'] // 2))
# Count alleles for each cohort
.pipe(sg.count_cohort_alleles)
# Compute Fst between the groups
.pipe(sg.Fst)
# Extract the Fst values for cohort pairs
Expand Down
8 changes: 7 additions & 1 deletion sgkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
create_genotype_call_dataset,
create_genotype_dosage_dataset,
)
from .stats.aggregation import count_call_alleles, count_variant_alleles, variant_stats
from .stats.aggregation import (
count_call_alleles,
count_cohort_alleles,
count_variant_alleles,
variant_stats,
)
from .stats.association import gwas_linear_regression
from .stats.hwe import hardy_weinberg_test
from .stats.pc_relate import pc_relate
Expand All @@ -26,6 +31,7 @@
"create_genotype_call_dataset",
"count_variant_alleles",
"count_call_alleles",
"count_cohort_alleles",
"create_genotype_dosage_dataset",
"display_genotypes",
"filter_partial_calls",
Expand Down
39 changes: 15 additions & 24 deletions sgkit/stats/popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from sgkit.window import has_windows, window_statistic

from .. import variables
from .aggregation import count_cohort_alleles, count_variant_alleles


def diversity(
Expand All @@ -37,7 +36,7 @@ def diversity(
ds
Genotype call dataset.
allele_counts
cohort allele counts to use or calculate. Defined by
cohort allele counts to use. Defined by
:data:`sgkit.variables.cohort_allele_count_spec`
call_genotype
Input variable name holding call_genotype as defined by
Expand All @@ -57,10 +56,7 @@ def diversity(
This method does not currently support datasets that are chunked along the
samples dimension.
"""
if allele_counts not in ds:
ds = count_cohort_alleles(ds, call_genotype=call_genotype)
else:
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
ac = ds[allele_counts]
an = ac.sum(axis=2)
n_pairs = an * (an - 1) / 2
Expand Down Expand Up @@ -162,7 +158,7 @@ def divergence(
ds
Genotype call dataset.
allele_counts
cohort allele counts to use or calculate. Defined by
cohort allele counts to use. Defined by
:data:`sgkit.variables.cohort_allele_count_spec`
call_genotype
Input variable name holding call_genotype as defined by
Expand All @@ -184,10 +180,7 @@ def divergence(
samples dimension.
"""

if allele_counts not in ds:
ds = count_cohort_alleles(ds, call_genotype=call_genotype)
else:
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
ac = ds[allele_counts]

n_variants = ds.dims["variants"]
Expand Down Expand Up @@ -311,7 +304,7 @@ def Fst(
Other supported estimators include ``Nei`` (1986), (the same estimator
as tskit).
allele_counts
cohort allele counts to use or calculate. Defined by
cohort allele counts to use. Defined by
:data:`sgkit.variables.cohort_allele_count_spec`
call_genotype
Input variable name holding call_genotype as defined by
Expand All @@ -338,10 +331,7 @@ def Fst(
f"Estimator '{estimator}' is not a known estimator: {known_estimators.keys()}"
)
estimator = estimator or "Hudson"
if allele_counts not in ds:
ds = count_cohort_alleles(ds, call_genotype=call_genotype)
else:
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})

n_cohorts = ds.dims["cohorts"]
gs = divergence(
Expand Down Expand Up @@ -371,10 +361,10 @@ def Tajimas_D(
ds
Genotype call dataset.
variant_allele_counts
variant allele counts to use or calculate. Defined by
variant allele counts to use. Defined by
:data:`sgkit.variables.variant_allele_counts_spec`
allele_counts
cohort allele counts to use or calculate. Defined by
cohort allele counts to use. Defined by
:data:`sgkit.variables.cohort_allele_count_spec`
call_genotype
Input variable name holding call_genotype as defined by
Expand All @@ -394,12 +384,13 @@ def Tajimas_D(
This method does not currently support datasets that are chunked along the
samples dimension.
"""
if variant_allele_counts not in ds:
ds = count_variant_alleles(ds, call_genotype=call_genotype)
else:
variables.validate(
ds, {variant_allele_counts: variables.variant_allele_count_spec}
)
variables.validate(
ds,
{
variant_allele_counts: variables.variant_allele_count_spec,
allele_counts: variables.cohort_allele_count_spec,
},
)
ac = ds[variant_allele_counts]

# count segregating
Expand Down
11 changes: 11 additions & 0 deletions sgkit/tests/test_popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sgkit import (
Fst,
Tajimas_D,
count_cohort_alleles,
count_variant_alleles,
create_genotype_call_dataset,
divergence,
Expand Down Expand Up @@ -57,6 +58,7 @@ def test_diversity(sample_size, chunks):
sample_cohorts = np.full_like(ts.samples(), 0)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
ds = ds.assign_coords({"cohorts": ["co_0"]})
ds = count_cohort_alleles(ds)
ds = diversity(ds)
div = ds.stat_diversity.sum(axis=0, skipna=False).sel(cohorts="co_0").values
ts_div = ts.diversity(span_normalise=False)
Expand All @@ -70,6 +72,7 @@ def test_diversity__windowed(sample_size):
sample_cohorts = np.full_like(ts.samples(), 0)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
ds = ds.assign_coords({"cohorts": ["co_0"]})
ds = count_cohort_alleles(ds)
ds = window(ds, size=25, step=25)
ds = diversity(ds)
div = ds["stat_diversity"].sel(cohorts="co_0").compute()
Expand Down Expand Up @@ -107,6 +110,7 @@ def test_divergence(sample_size, n_cohorts, chunks):
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = count_cohort_alleles(ds)
ds = divergence(ds)
div = ds.stat_divergence.sum(axis=0, skipna=False).values

Expand Down Expand Up @@ -136,6 +140,7 @@ def test_divergence__windowed(sample_size, n_cohorts, chunks):
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = count_cohort_alleles(ds)
ds = window(ds, size=25, step=25)
ds = divergence(ds)
div = ds["stat_divergence"].values
Expand Down Expand Up @@ -169,6 +174,7 @@ def test_divergence__windowed_scikit_allel_comparison(sample_size, n_cohorts, ch
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = count_cohort_alleles(ds)
ds = window(ds, size=25, step=25)
ds = divergence(ds)
div = ds["stat_divergence"].values
Expand Down Expand Up @@ -202,6 +208,7 @@ def test_Fst__Hudson(sample_size):
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = count_cohort_alleles(ds)
n_variants = ds.dims["variants"]
ds = window(ds, size=n_variants, step=n_variants) # single window
ds = Fst(ds, estimator="Hudson")
Expand Down Expand Up @@ -230,6 +237,7 @@ def test_Fst__Nei(sample_size, n_cohorts):
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = count_cohort_alleles(ds)
n_variants = ds.dims["variants"]
ds = window(ds, size=n_variants, step=n_variants) # single window
ds = Fst(ds, estimator="Nei")
Expand Down Expand Up @@ -266,6 +274,7 @@ def test_Fst__windowed(sample_size, n_cohorts, chunks):
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
cohort_names = [f"co_{i}" for i in range(n_cohorts)]
ds = ds.assign_coords({"cohorts_0": cohort_names, "cohorts_1": cohort_names})
ds = count_cohort_alleles(ds)
ds = window(ds, size=25, step=25)
fst_ds = Fst(ds, estimator="Nei")
fst = fst_ds["stat_Fst"].values
Expand Down Expand Up @@ -302,6 +311,8 @@ def test_Tajimas_D(sample_size):
ds = ts_to_dataset(ts) # type: ignore[no-untyped-call]
sample_cohorts = np.full_like(ts.samples(), 0)
ds["sample_cohort"] = xr.DataArray(sample_cohorts, dims="samples")
ds = count_variant_alleles(ds)
ds = count_cohort_alleles(ds)
n_variants = ds.dims["variants"]
ds = window(ds, size=n_variants, step=n_variants) # single window
ds = Tajimas_D(ds)
Expand Down