Skip to content

Commit 34779d2

Browse files
committed
Avoid divide by zero error
1 parent 7f9ce43 commit 34779d2

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

sgkit/stats/popgen.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ def diversity(
4646
n_pairs = an * (an - 1) / 2
4747
n_same = (ac * (ac - 1) / 2).sum(axis=2)
4848
n_diff = n_pairs - n_same
49-
pi = n_diff / n_pairs
50-
pi_sum = pi.sum(axis=0)
49+
# replace zeros to avoid divide by zero error
50+
n_pairs_na = n_pairs.where(n_pairs != 0)
51+
pi = n_diff / n_pairs_na
52+
pi_sum = pi.sum(axis=0, skipna=False)
5153
return pi_sum # type: ignore[no-any-return]
5254

5355

@@ -98,7 +100,7 @@ def Fst(ds: Dataset, allele_counts: Hashable = "cohort_allele_count",) -> DataAr
98100
DataArray
99101
fst value between the two cohorts.
100102
"""
101-
total_div = diversity(ds, allele_counts).sum()
103+
total_div = diversity(ds, allele_counts).sum(skipna=False)
102104
gs = divergence(ds, allele_counts)
103105
den = total_div + 2 * gs # type: ignore[operator]
104106
fst = 1 - (2 * total_div / den)

sgkit/tests/test_popgen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_divergence(size):
6060
np.testing.assert_allclose(div, ts_div)
6161

6262

63-
@pytest.mark.parametrize("size", [10, 100])
63+
@pytest.mark.parametrize("size", [2, 3, 10, 100])
6464
def test_Fst(size):
6565
ts = msprime.simulate(size, length=100, mutation_rate=0.05, random_seed=42)
6666
subset_1 = ts.samples()[: ts.num_samples // 2]

0 commit comments

Comments
 (0)