Skip to content

Commit d652c34

Browse files
ravwojdyla authored and mergify[bot] committed
Update variable changes after recent popgen changes
1 parent ec4a160 commit d652c34

File tree

3 files changed

+146
-36
lines changed

3 files changed

+146
-36
lines changed

sgkit/stats/aggregation.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from typing_extensions import Literal
88
from xarray import Dataset
99

10-
from sgkit.stats.utils import assert_array_shape
1110
from sgkit import variables
11+
from sgkit.stats.utils import assert_array_shape
1212
from sgkit.typing import ArrayLike
1313
from sgkit.utils import conditional_merge_datasets
1414

@@ -86,7 +86,10 @@ def _count_cohort_alleles(
8686

8787

8888
def count_call_alleles(
89-
ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True
89+
ds: Dataset,
90+
*,
91+
call_genotype: Hashable = variables.call_genotype,
92+
merge: bool = True,
9093
) -> Dataset:
9194
"""Compute per sample allele counts from genotype calls.
9295
@@ -156,7 +159,10 @@ def count_call_alleles(
156159

157160

158161
def count_variant_alleles(
159-
ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True
162+
ds: Dataset,
163+
*,
164+
call_genotype: Hashable = variables.call_genotype,
165+
merge: bool = True,
160166
) -> Dataset:
161167
"""Compute allele count from genotype calls.
162168
@@ -213,14 +219,22 @@ def count_variant_alleles(
213219
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
214220

215221

216-
def count_cohort_alleles(ds: Dataset, merge: bool = True) -> Dataset:
222+
def count_cohort_alleles(
223+
ds: Dataset,
224+
*,
225+
call_genotype: Hashable = variables.call_genotype,
226+
merge: bool = True,
227+
) -> Dataset:
217228
"""Compute per cohort allele counts from genotype calls.
218229
219230
Parameters
220231
----------
221232
ds
222233
Genotype call dataset such as from
223234
`sgkit.create_genotype_call_dataset`.
235+
call_genotype
236+
Input variable name holding call_genotype as defined by
237+
:data:`sgkit.variables.call_genotype_spec`
224238
merge
225239
If True (the default), merge the input dataset and the computed
226240
output variables into a single dataset, otherwise return only
@@ -237,7 +251,7 @@ def count_cohort_alleles(ds: Dataset, merge: bool = True) -> Dataset:
237251
n_variants = ds.dims["variants"]
238252
n_alleles = ds.dims["alleles"]
239253

240-
ds = count_call_alleles(ds)
254+
ds = count_call_alleles(ds, call_genotype=call_genotype)
241255
AC, SC = da.asarray(ds.call_allele_count), da.asarray(ds.sample_cohort)
242256
n_cohorts = SC.max().compute() + 1 # 0-based indexing
243257
C = da.empty(n_cohorts, dtype=np.uint8)
@@ -255,8 +269,10 @@ def count_cohort_alleles(ds: Dataset, merge: bool = True) -> Dataset:
255269
AC = da.stack([AC.blocks[:, i] for i in range(AC.numblocks[1])]).sum(axis=0)
256270
assert_array_shape(AC, n_variants, n_cohorts, n_alleles)
257271

258-
new_ds = Dataset({"cohort_allele_count": (("variants", "cohorts", "alleles"), AC)})
259-
return conditional_merge_datasets(ds, new_ds, merge)
272+
new_ds = Dataset(
273+
{variables.cohort_allele_count: (("variants", "cohorts", "alleles"), AC)}
274+
)
275+
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
260276

261277

262278
def _swap(dim: Dimension) -> Dimension:

sgkit/stats/popgen.py

Lines changed: 104 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,16 @@
99
from sgkit.typing import ArrayLike
1010
from sgkit.utils import conditional_merge_datasets
1111

12+
from .. import variables
1213
from .aggregation import count_cohort_alleles, count_variant_alleles
1314

1415

1516
def diversity(
16-
ds: Dataset, allele_counts: Hashable = "cohort_allele_count", merge: bool = True
17+
ds: Dataset,
18+
*,
19+
allele_counts: Hashable = variables.cohort_allele_count,
20+
call_genotype: Hashable = variables.call_genotype,
21+
merge: bool = True
1722
) -> Dataset:
1823
"""Compute diversity from cohort allele counts.
1924
@@ -31,19 +36,30 @@ def diversity(
3136
ds
3237
Genotype call dataset.
3338
allele_counts
34-
cohort allele counts to use or calculate.
39+
cohort allele counts to use or calculate. Defined by
40+
:data:`sgkit.variables.cohort_allele_count_spec`
41+
call_genotype
42+
Input variable name holding call_genotype as defined by
43+
:data:`sgkit.variables.call_genotype_spec`
44+
merge
45+
If True (the default), merge the input dataset and the computed
46+
output variables into a single dataset, otherwise return only
47+
the computed output variables.
48+
See :ref:`dataset_merge` for more details.
3549
3650
Returns
3751
-------
38-
diversity value.
52+
diversity value, as defined by :data:`sgkit.variables.stat_diversity_spec`.
3953
4054
Warnings
4155
--------
4256
This method does not currently support datasets that are chunked along the
4357
samples dimension.
4458
"""
4559
if allele_counts not in ds:
46-
ds = count_cohort_alleles(ds)
60+
ds = count_cohort_alleles(ds, call_genotype=call_genotype)
61+
else:
62+
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
4763
ac = ds[allele_counts]
4864
an = ac.sum(axis=2)
4965
n_pairs = an * (an - 1) / 2
@@ -55,13 +71,13 @@ def diversity(
5571
pi_sum = pi.sum(axis=0, skipna=False)
5672
new_ds = Dataset(
5773
{
58-
"stat_diversity": (
74+
variables.stat_diversity: (
5975
"cohorts",
6076
pi_sum,
6177
)
6278
}
6379
)
64-
return conditional_merge_datasets(ds, new_ds, merge)
80+
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
6581

6682

6783
# c = cohorts, k = alleles
@@ -100,7 +116,11 @@ def _divergence(ac: ArrayLike, an: ArrayLike, out: ArrayLike) -> None:
100116

101117

102118
def divergence(
103-
ds: Dataset, allele_counts: Hashable = "cohort_allele_count", merge: bool = True
119+
ds: Dataset,
120+
*,
121+
call_genotype: Hashable = variables.call_genotype,
122+
allele_counts: Hashable = variables.cohort_allele_count,
123+
merge: bool = True
104124
) -> Dataset:
105125
"""Compute divergence between pairs of cohorts.
106126
@@ -109,11 +129,21 @@ def divergence(
109129
ds
110130
Genotype call dataset.
111131
allele_counts
112-
cohort allele counts to use or calculate.
132+
cohort allele counts to use or calculate. Defined by
133+
:data:`sgkit.variables.cohort_allele_count_spec`
134+
call_genotype
135+
Input variable name holding call_genotype as defined by
136+
:data:`sgkit.variables.call_genotype_spec`
137+
merge
138+
If True (the default), merge the input dataset and the computed
139+
output variables into a single dataset, otherwise return only
140+
the computed output variables.
141+
See :ref:`dataset_merge` for more details.
113142
114143
Returns
115144
-------
116-
divergence value between pairs of cohorts.
145+
divergence value between pairs of cohorts, as defined by
146+
:data:`sgkit.variables.stat_divergence_spec`.
117147
118148
Warnings
119149
--------
@@ -122,7 +152,9 @@ def divergence(
122152
"""
123153

124154
if allele_counts not in ds:
125-
ds = count_cohort_alleles(ds)
155+
ds = count_cohort_alleles(ds, call_genotype=call_genotype)
156+
else:
157+
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
126158
ac = ds[allele_counts]
127159
an = ac.sum(axis=2)
128160

@@ -137,8 +169,8 @@ def divergence(
137169
d_sum = d.sum(axis=0)
138170
assert_array_shape(d_sum, n_cohorts, n_cohorts)
139171

140-
new_ds = Dataset({"stat_divergence": (("cohorts_0", "cohorts_1"), d_sum)})
141-
return conditional_merge_datasets(ds, new_ds, merge)
172+
new_ds = Dataset({variables.stat_divergence: (("cohorts_0", "cohorts_1"), d_sum)})
173+
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
142174

143175

144176
# c = cohorts
@@ -169,7 +201,11 @@ def _pairwise_sum(d: ArrayLike, out: ArrayLike) -> None:
169201

170202

171203
def Fst(
172-
ds: Dataset, allele_counts: Hashable = "cohort_allele_count", merge: bool = True
204+
ds: Dataset,
205+
*,
206+
call_genotype: Hashable = variables.call_genotype,
207+
allele_counts: Hashable = variables.cohort_allele_count,
208+
merge: bool = True
173209
) -> Dataset:
174210
"""Compute Fst between pairs of cohorts.
175211
@@ -178,21 +214,35 @@ def Fst(
178214
ds
179215
Genotype call dataset.
180216
allele_counts
181-
cohort allele counts to use or calculate.
217+
cohort allele counts to use or calculate. Defined by
218+
:data:`sgkit.variables.cohort_allele_count_spec`
219+
call_genotype
220+
Input variable name holding call_genotype as defined by
221+
:data:`sgkit.variables.call_genotype_spec`
222+
merge
223+
If True (the default), merge the input dataset and the computed
224+
output variables into a single dataset, otherwise return only
225+
the computed output variables.
226+
See :ref:`dataset_merge` for more details.
182227
183228
Returns
184229
-------
185-
Fst value between pairs of cohorts.
230+
Fst value between pairs of cohorts, as defined by
231+
:data:`sgkit.variables.stat_Fst_spec`.
186232
187233
Warnings
188234
--------
189235
This method does not currently support datasets that are chunked along the
190236
samples dimension.
191237
"""
192238
if allele_counts not in ds:
193-
ds = count_cohort_alleles(ds)
239+
ds = count_cohort_alleles(ds, call_genotype=call_genotype)
240+
else:
241+
variables.validate(ds, {allele_counts: variables.cohort_allele_count_spec})
194242
n_cohorts = ds.dims["cohorts"]
195-
div = diversity(ds, allele_counts, merge=False).stat_diversity
243+
div = diversity(
244+
ds, allele_counts=allele_counts, call_genotype=call_genotype, merge=False
245+
).stat_diversity
196246
assert_array_shape(div, n_cohorts)
197247

198248
# calculate diversity pairs
@@ -201,37 +251,60 @@ def Fst(
201251
div_pairs = da.map_blocks(_pairwise_sum, div, chunks=shape, dtype=np.float64)
202252
assert_array_shape(div_pairs, n_cohorts, n_cohorts)
203253

204-
gs = divergence(ds, allele_counts, merge=False).stat_divergence
254+
gs = divergence(
255+
ds, allele_counts=allele_counts, call_genotype=call_genotype, merge=False
256+
).stat_divergence
205257
den = div_pairs + 2 * gs
206258
fst = 1 - (2 * div_pairs / den)
207-
new_ds = Dataset({"stat_Fst": fst})
208-
return conditional_merge_datasets(ds, new_ds, merge)
259+
new_ds = Dataset({variables.stat_Fst: fst})
260+
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
209261

210262

211263
def Tajimas_D(
212-
ds: Dataset, allele_counts: Hashable = "variant_allele_count", merge: bool = True
264+
ds: Dataset,
265+
*,
266+
call_genotype: Hashable = variables.call_genotype,
267+
variant_allele_counts: Hashable = variables.variant_allele_count,
268+
allele_counts: Hashable = variables.cohort_allele_count,
269+
merge: bool = True
213270
) -> Dataset:
214271
"""Compute Tajimas' D for a genotype call dataset.
215272
216273
Parameters
217274
----------
218275
ds
219276
Genotype call dataset.
277+
variant_allele_counts
278+
variant allele counts to use or calculate. Defined by
279+
:data:`sgkit.variables.variant_allele_count_spec`
220280
allele_counts
221-
allele counts to use or calculate.
281+
cohort allele counts to use or calculate. Defined by
282+
:data:`sgkit.variables.cohort_allele_count_spec`
283+
call_genotype
284+
Input variable name holding call_genotype as defined by
285+
:data:`sgkit.variables.call_genotype_spec`
286+
merge
287+
If True (the default), merge the input dataset and the computed
288+
output variables into a single dataset, otherwise return only
289+
the computed output variables.
290+
See :ref:`dataset_merge` for more details.
222291
223292
Returns
224293
-------
225-
Tajimas' D value.
294+
Tajima's D value, as defined by :data:`sgkit.variables.stat_Tajimas_D_spec`.
226295
227296
Warnings
228297
--------
229298
This method does not currently support datasets that are chunked along the
230299
samples dimension.
231300
"""
232-
if allele_counts not in ds:
233-
ds = count_variant_alleles(ds)
234-
ac = ds[allele_counts]
301+
if variant_allele_counts not in ds:
302+
ds = count_variant_alleles(ds, call_genotype=call_genotype)
303+
else:
304+
variables.validate(
305+
ds, {variant_allele_counts: variables.variant_allele_count_spec}
306+
)
307+
ac = ds[variant_allele_counts]
235308

236309
# count segregating
237310
S = ((ac > 0).sum(axis=1) > 1).sum()
@@ -246,7 +319,9 @@ def Tajimas_D(
246319
theta = S / a1
247320

248321
# calculate diversity
249-
div = diversity(ds).stat_diversity
322+
div = diversity(
323+
ds, allele_counts=allele_counts, call_genotype=call_genotype, merge=False
324+
).stat_diversity
250325

251326
# N.B., both theta estimates are usually divided by the number of
252327
# (accessible) bases but here we want the absolute difference
@@ -268,5 +343,5 @@ def Tajimas_D(
268343
# finally calculate Tajima's D
269344
D = d / d_stdev
270345

271-
new_ds = Dataset({"stat_Tajimas_D": D})
272-
return conditional_merge_datasets(ds, new_ds, merge)
346+
new_ds = Dataset({variables.stat_Tajimas_D: D})
347+
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

sgkit/variables.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ def _check_field(
190190
ArrayLikeSpec("call_genotype_probability_mask", kind="b", ndim=3)
191191
)
192192
"""TODO"""
193+
cohort_allele_count, cohort_allele_count_spec = SgkitVariables.register_variable(
194+
ArrayLikeSpec("cohort_allele_count", kind="i", ndim=3)
195+
)
193196
covariates, covariates_spec = SgkitVariables.register_variable(
194197
ArrayLikeSpec("covariates", ndim={1, 2})
195198
)
@@ -244,6 +247,22 @@ def _check_field(
244247
ArrayLikeSpec("sample_pcs", ndim=2, kind="f")
245248
)
246249
"""Sample PCs (PCxS)."""
250+
stat_Fst, stat_Fst_spec = SgkitVariables.register_variable(
251+
ArrayLikeSpec("stat_Fst", ndim=2, kind="f")
252+
)
253+
"""TODO"""
254+
stat_divergence, stat_divergence_spec = SgkitVariables.register_variable(
255+
ArrayLikeSpec("stat_divergence", ndim=2, kind="f")
256+
)
257+
"""TODO"""
258+
stat_diversity, stat_diversity_spec = SgkitVariables.register_variable(
259+
ArrayLikeSpec("stat_diversity", ndim=1, kind="f")
260+
)
261+
"""TODO"""
262+
stat_Tajimas_D, stat_Tajimas_D_spec = SgkitVariables.register_variable(
263+
ArrayLikeSpec("stat_Tajimas_D", ndim={0, 1}, kind="f")
264+
)
265+
"""TODO"""
247266
traits, traits_spec = SgkitVariables.register_variable(
248267
ArrayLikeSpec("traits", ndim={1, 2})
249268
)

0 commit comments

Comments
 (0)