diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index b84ff1fe87..14decb18ff 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -287,7 +287,8 @@ verify_divergence_matrix(tsk_treeseq_t *ts, tsk_flags_t options) ts, n, sample_set_sizes, samples, n * n, index_tuples, 0, NULL, options, D1); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_divergence_matrix(ts, 0, NULL, 0, NULL, options, D2); + ret = tsk_treeseq_divergence_matrix( + ts, n, sample_set_sizes, samples, 0, NULL, options, D2); CU_ASSERT_EQUAL_FATAL(ret, 0); for (j = 0; j < n; j++) { @@ -1057,19 +1058,42 @@ test_single_tree_divergence_matrix(void) int ret; double result[16]; double D_branch[16] = { 0, 2, 6, 6, 2, 0, 6, 6, 6, 6, 0, 4, 6, 6, 4, 0 }; - double D_site[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + double D_site[16] = { 0, 1, 1, 0, 1, 0, 2, 1, 1, 2, 0, 1, 0, 1, 1, 0 }; - tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, - NULL, NULL, NULL, 0); + tsk_size_t sample_set_sizes[] = { 2, 2 }; - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(16, result, D_branch); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(16, result, D_site); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_set_sizes, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_set_sizes, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + sample_set_sizes[0] = 3; + sample_set_sizes[1] = 1; + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_set_sizes, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_divergence_matrix( + &ts, 2, sample_set_sizes, NULL, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + /* assert_arrays_almost_equal(4, result, D_site); */ + verify_divergence_matrix(&ts, TSK_STAT_BRANCH); verify_divergence_matrix(&ts, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE); verify_divergence_matrix(&ts, TSK_STAT_SITE); @@ -1083,7 +1107,7 @@ test_single_tree_divergence_matrix_internal_samples(void) { tsk_treeseq_t ts; int ret; - double result[16]; + double *result = malloc(16 * sizeof(double)); double D[16] = { 0, 2, 4, 3, 2, 0, 4, 3, 4, 4, 0, 1, 3, 3, 1, 0 }; const char *nodes = "1 0 -1 -1\n" /* 2.00┊ 6 ┊ */ @@ -1109,14 +1133,35 @@ test_single_tree_divergence_matrix_internal_samples(void) "3 3 T -1\n" "4 4 T -1\n" "5 5 T -1\n"; + tsk_id_t samples[] = { 0, 1, 2, 5 }; + tsk_size_t sizes[] = { 1, 1, 1, 1 }; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(16, result, D); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 4, sizes, samples, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + ret = tsk_treeseq_divergence_matrix( + &ts, 4, sizes, samples, 0, NULL, TSK_STAT_SITE, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + + ret = tsk_treeseq_divergence_matrix( + &ts, 4, NULL, samples, 0, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(16, result, D); + ret = tsk_treeseq_divergence_matrix( + &ts, 4, NULL, samples, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(16, result, D); @@ -1126,6 +1171,7 @@ test_single_tree_divergence_matrix_internal_samples(void) verify_divergence_matrix(&ts, TSK_STAT_SITE | TSK_STAT_SPAN_NORMALISE); tsk_treeseq_free(&ts); + free(result); } static void @@ -1164,7 +1210,8 @@ test_single_tree_divergence_matrix_multi_root(void) tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(16, result, D_branch); @@ -2424,73 +2471,84 @@ test_simplest_divergence_matrix(void) "1 0 0\n" "0 1 0\n"; const char *edges = "0 1 2 0,1\n"; + const char *sites = "0.1 A\n" + "0.6 A\n"; + const char *mutations = "0 0 B -1\n" + "1 0 B -1\n"; tsk_treeseq_t ts; tsk_id_t sample_ids[] = { 0, 1 }; double D_branch[4] = { 0, 2, 2, 0 }; - double D_site[4] = { 0, 0, 0, 0 }; + double D_site[4] = { 0, 2, 2, 0 }; double result[4]; int ret; - tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + &ts, 2, NULL, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_branch, result); - ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, NULL, + TSK_STAT_BRANCH | TSK_STAT_SPAN_NORMALISE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_branch, result); - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_site, result); ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); + &ts, 2, NULL, sample_ids, 0, NULL, TSK_STAT_SPAN_NORMALISE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_site, result); ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 0, NULL, TSK_STAT_SITE, result); + &ts, 2, NULL, sample_ids, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_site, result); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_branch, result); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_SITE, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_site, result); - ret = tsk_treeseq_divergence_matrix(&ts, 0, NULL, 0, NULL, TSK_STAT_NODE, result); + ret = tsk_treeseq_divergence_matrix( + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_NODE, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); ret = tsk_treeseq_divergence_matrix( - &ts, 0, NULL, 0, NULL, TSK_STAT_POLARISED, result); + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_POLARISED, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_POLARISED_UNSUPPORTED); ret = tsk_treeseq_divergence_matrix( - &ts, 0, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + &ts, 0, NULL, NULL, 0, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); sample_ids[0] = -1; - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); sample_ids[0] = 3; - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); sample_ids[0] = 1; - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + &ts, 2, NULL, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); + sample_ids[0] = 2; + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLES); + tsk_treeseq_free(&ts); } @@ -2515,39 +2573,39 @@ test_simplest_divergence_matrix_windows(void) tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 2, windows, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(8, D_site, result); ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 2, windows, TSK_STAT_BRANCH, result); + &ts, 2, NULL, sample_ids, 2, windows, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(8, D_branch, result); /* Windows for the second half */ ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 1, windows + 1, TSK_STAT_SITE, result); + &ts, 2, NULL, sample_ids, 1, windows + 1, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_site, result); ret = tsk_treeseq_divergence_matrix( - &ts, 2, sample_ids, 1, windows + 1, TSK_STAT_BRANCH, result); + &ts, 2, NULL, sample_ids, 1, windows + 1, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(4, D_branch, result); - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, windows, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 0, windows, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); windows[0] = -1; - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 2, windows, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = 0.45; windows[2] = 1.5; - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 2, windows, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = 0.55; windows[2] = 1.0; - ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 2, windows, 0, result); + ret = tsk_treeseq_divergence_matrix(&ts, 2, NULL, sample_ids, 2, windows, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); tsk_treeseq_free(&ts); @@ -2558,7 +2616,7 @@ test_simplest_divergence_matrix_internal_sample(void) { const char *nodes = "1 0 0\n" "1 0 0\n" - "0 1 0\n"; + "1 1 0\n"; const char *edges = "0 1 2 0,1\n"; tsk_treeseq_t ts; tsk_id_t sample_ids[] = { 0, 1, 2 }; @@ -2570,12 +2628,12 @@ test_simplest_divergence_matrix_internal_sample(void) tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); ret = tsk_treeseq_divergence_matrix( - &ts, 3, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); + &ts, 3, NULL, sample_ids, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(9, D_branch, result); ret = tsk_treeseq_divergence_matrix( - &ts, 3, sample_ids, 0, NULL, TSK_STAT_SITE, result); + &ts, 3, NULL, sample_ids, 0, NULL, TSK_STAT_SITE, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(9, D_site, result); diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index a9fee8a8ff..d5dce3cd93 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -8008,9 +8008,10 @@ test_time_uncalibrated(void) TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, sigma); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, TSK_STAT_BRANCH, result); + ret = tsk_treeseq_divergence_matrix( + &ts2, 0, NULL, NULL, 0, NULL, TSK_STAT_BRANCH, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_UNCALIBRATED); - ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, 0, NULL, + ret = tsk_treeseq_divergence_matrix(&ts2, 0, NULL, NULL, 0, NULL, TSK_STAT_BRANCH | TSK_STAT_ALLOW_TIME_UNCALIBRATED, result); CU_ASSERT_EQUAL_FATAL(ret, 0); diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 043ae5ceab..8dca6d0720 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -966,16 +966,6 @@ tskit_suite_init(void) return CUE_SUCCESS; } -void -assert_arrays_almost_equal(tsk_size_t len, double *a, double *b) -{ - tsk_size_t j; - - for (j = 0; j < len; j++) { - CU_ASSERT_DOUBLE_EQUAL(a[j], b[j], 1e-9); - } -} - static int tskit_suite_cleanup(void) { diff --git a/c/tests/testlib.h b/c/tests/testlib.h index 69efb14781..4885fba356 100644 --- a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -54,7 +54,16 @@ void parse_individuals(const char *text, tsk_individual_table_t *individual_tabl void unsort_edges(tsk_edge_table_t *edges, size_t start); -void assert_arrays_almost_equal(tsk_size_t len, double *a, double *b); +/* Use a macro so we can get line numbers at roughly the right place */ +#define assert_arrays_almost_equal(len, a, b) \ + { \ + do { \ + tsk_size_t _j; \ + for (_j = 0; _j < len; _j++) { \ + CU_ASSERT_DOUBLE_EQUAL(a[_j], b[_j], 1e-9); \ + } \ + } while (0); \ + } extern const char *single_tree_ex_nodes; extern const char *single_tree_ex_edges; diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 56991b5f40..56b2661c12 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2019-2023 Tskit Developers + * Copyright (c) 2019-2024 Tskit Developers * Copyright (c) 2015-2018 University of Oxford * * Permission is hereby granted, free of charge, to any person obtaining a copy @@ -7288,39 +7288,21 @@ sv_tables_mrca(const sv_tables_t *self, tsk_id_t x, tsk_id_t y) } static int -tsk_treeseq_check_node_bounds( - const tsk_treeseq_t *self, tsk_size_t num_nodes, const tsk_id_t *nodes) -{ - int ret = 0; - tsk_size_t j; - tsk_id_t u; - const tsk_id_t N = (tsk_id_t) self->tables->nodes.num_rows; - - for (j = 0; j < num_nodes; j++) { - u = nodes[j]; - if (u < 0 || u >= N) { - ret = TSK_ERR_NODE_OUT_OF_BOUNDS; - goto out; - } - } -out: - return ret; -} - -static int -tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_samples, - const tsk_id_t *restrict samples, tsk_size_t num_windows, +tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *restrict sample_set_sizes, + const tsk_id_t *restrict sample_sets, tsk_size_t num_windows, const double *restrict windows, tsk_flags_t options, double *restrict result) { int ret = 0; tsk_tree_t tree; const double *restrict nodes_time = self->tables->nodes.time; - const tsk_size_t n = num_samples; - tsk_size_t i, j, k; + const tsk_size_t N = num_sample_sets; + tsk_size_t i, j, k, offset, sj, sk; tsk_id_t u, v, w, u_root, v_root; double tu, tv, d, span, left, right, span_left, span_right; double *restrict D; sv_tables_t sv; + tsk_size_t *ss_offsets = tsk_malloc((num_sample_sets + 1) * sizeof(*ss_offsets)); memset(&sv, 0, sizeof(sv)); ret = tsk_tree_init(&tree, self, 0); @@ -7331,16 +7313,26 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s if (ret != 0) { goto out; } - + if (ss_offsets == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } if (self->time_uncalibrated && !(options & TSK_STAT_ALLOW_TIME_UNCALIBRATED)) { ret = TSK_ERR_TIME_UNCALIBRATED; goto out; } + ss_offsets[0] = 0; + offset = 0; + for (j = 0; j < N; j++) { + offset += sample_set_sizes[j]; + ss_offsets[j + 1] = offset; + } + for (i = 0; i < num_windows; i++) { left = windows[i]; right = windows[i + 1]; - D = result + i * n * n; + D = result + i * N * N; ret = tsk_tree_seek(&tree, left, 0); if (ret != 0) { goto out; @@ -7350,24 +7342,34 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s span_right = TSK_MIN(tree.interval.right, right); span = span_right - span_left; sv_tables_build(&sv, &tree); - for (j = 0; j < n; j++) { - u = samples[j]; - for (k = j + 1; k < n; k++) { - v = samples[k]; - w = sv_tables_mrca(&sv, u, v); - if (w != TSK_NULL) { - u_root = w; - v_root = w; - } else { - /* Slow path - only happens for nodes in disconnected - * subtrees in a tree with multiple roots */ - u_root = tsk_tree_get_node_root(&tree, u); - v_root = tsk_tree_get_node_root(&tree, v); + for (sj = 0; sj < N; sj++) { + for (j = ss_offsets[sj]; j < ss_offsets[sj + 1]; j++) { + u = sample_sets[j]; + for (sk = sj; sk < N; sk++) { + for (k = ss_offsets[sk]; k < ss_offsets[sk + 1]; k++) { + v = sample_sets[k]; + if (u == v) { + /* This case contributes zero to divergence, so + * short-circuit to save time. + * TODO is there a better way to do this? */ + continue; + } + w = sv_tables_mrca(&sv, u, v); + if (w != TSK_NULL) { + u_root = w; + v_root = w; + } else { + /* Slow path - only happens for nodes in disconnected + * subtrees in a tree with multiple roots */ + u_root = tsk_tree_get_node_root(&tree, u); + v_root = tsk_tree_get_node_root(&tree, v); + } + tu = nodes_time[u_root] - nodes_time[u]; + tv = nodes_time[v_root] - nodes_time[v]; + d = (tu + tv) * span; + D[sj * N + sk] += d; + } } - tu = nodes_time[u_root] - nodes_time[u]; - tv = nodes_time[v_root] - nodes_time[v]; - d = (tu + tv) * span; - D[j * n + k] += d; } } ret = tsk_tree_next(&tree); @@ -7380,6 +7382,7 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s out: tsk_tree_free(&tree); sv_tables_free(&sv); + tsk_safe_free(ss_offsets); return ret; } @@ -7391,14 +7394,13 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s static void update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, - const tsk_size_t *restrict offsets, double *D) + const tsk_size_t *restrict offsets, const tsk_size_t num_sample_sets, double *D) { const tsk_size_t num_alleles = var->num_alleles; - const tsk_id_t n = (tsk_id_t) var->num_samples; - tsk_size_t a, b, j, k; tsk_id_t u, v; + double increment; for (a = 0; a < num_alleles; a++) { for (b = a + 1; b < num_alleles; b++) { @@ -7409,10 +7411,14 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, /* Only increment the upper triangle to (hopefully) improve memory * access patterns */ if (u > v) { - v = A[j]; u = A[k]; + v = A[j]; + } + increment = 1; + if (u == v) { + increment = 2; } - D[u * n + v]++; + D[u * (tsk_id_t) num_sample_sets + v] += increment; } } } @@ -7441,8 +7447,23 @@ group_alleles(const tsk_variant_t *var, tsk_id_t *restrict A, tsk_size_t *offset } } +static void +remap_to_sample_sets(const tsk_size_t num_samples, const tsk_id_t *restrict samples, + const tsk_id_t *restrict sample_set_index_map, tsk_id_t *restrict A) +{ + tsk_size_t j; + tsk_id_t u; + for (j = 0; j < num_samples; j++) { + u = samples[A[j]]; + tsk_bug_assert(u >= 0); + tsk_bug_assert(sample_set_index_map[u] >= 0); + A[j] = sample_set_index_map[u]; + } +} + static int -tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples, +tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_id_t *restrict sample_set_index_map, const tsk_size_t num_samples, const tsk_id_t *restrict samples, tsk_size_t num_windows, const double *restrict windows, tsk_flags_t TSK_UNUSED(options), double *restrict result) @@ -7460,6 +7481,8 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam tsk_size_t *allele_offsets = NULL; tsk_variant_t variant; + /* FIXME it's not clear that using TSK_ISOLATED_NOT_MISSING is + * correct here */ ret = tsk_variant_init( &variant, self, samples, num_samples, NULL, TSK_ISOLATED_NOT_MISSING); if (ret != 0) { @@ -7478,7 +7501,7 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam for (i = 0; i < num_windows; i++) { left = windows[i]; right = windows[i + 1]; - D = result + i * num_samples * num_samples; + D = result + i * num_sample_sets * num_sample_sets; if (site_id < num_sites) { tsk_bug_assert(sites_position[site_id] >= left); @@ -7500,7 +7523,8 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam } } group_alleles(&variant, A, allele_offsets); - update_site_divergence(&variant, A, allele_offsets, D); + remap_to_sample_sets(num_samples, samples, sample_set_index_map, A); + update_site_divergence(&variant, A, allele_offsets, num_sample_sets, D); site_id++; } } @@ -7512,50 +7536,73 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam return ret; } +/* Return the mapping from node IDs to the index of the sample set + * they belong to, or -1 of none. Error if a node is in more than one + * set. + */ static int -get_sample_index_map(const tsk_size_t num_nodes, const tsk_size_t num_samples, - const tsk_id_t *restrict samples, tsk_id_t **ret_sample_index_map) +get_sample_set_index_map(const tsk_treeseq_t *self, const tsk_size_t num_sample_sets, + const tsk_size_t *restrict sample_set_sizes, const tsk_id_t *restrict sample_sets, + tsk_size_t *ret_total_samples, tsk_id_t *restrict node_index_map) { int ret = 0; - tsk_size_t j; + tsk_size_t i, j, k; tsk_id_t u; - tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map)); - - if (sample_index_map == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - /* Assign the output pointer here so that it will be freed in the case - * of an error raised in the input checking */ - *ret_sample_index_map = sample_index_map; + tsk_size_t total_samples = 0; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + const tsk_flags_t *restrict node_flags = self->tables->nodes.flags; for (j = 0; j < num_nodes; j++) { - sample_index_map[j] = TSK_NULL; + node_index_map[j] = TSK_NULL; } - for (j = 0; j < num_samples; j++) { - u = samples[j]; - if (sample_index_map[u] != TSK_NULL) { - ret = TSK_ERR_DUPLICATE_SAMPLE; - goto out; + i = 0; + for (j = 0; j < num_sample_sets; j++) { + total_samples += sample_set_sizes[j]; + for (k = 0; k < sample_set_sizes[j]; k++) { + u = sample_sets[i]; + i++; + if (u < 0 || u >= (tsk_id_t) num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + /* Note: we require nodes to be samples because we have to think + * about how to normalise by the length of genome that the node + * is 'in' the tree for each window otherwise. */ + if (!(node_flags[u] & TSK_NODE_IS_SAMPLE)) { + ret = TSK_ERR_BAD_SAMPLES; + goto out; + } + if (node_index_map[u] != TSK_NULL) { + ret = TSK_ERR_DUPLICATE_SAMPLE; + goto out; + } + node_index_map[u] = (tsk_id_t) j; } - sample_index_map[u] = (tsk_id_t) j; } + *ret_total_samples = total_samples; out: return ret; } static void -fill_lower_triangle( - double *restrict result, const tsk_size_t n, const tsk_size_t num_windows) +fill_lower_triangle_count_normalise(const tsk_size_t num_windows, const tsk_size_t n, + const tsk_size_t *set_sizes, double *restrict result) { tsk_size_t i, j, k; + double denom; double *restrict D; /* TODO there's probably a better striding pattern that could be used here */ for (i = 0; i < num_windows; i++) { D = result + i * n * n; for (j = 0; j < n; j++) { + denom = (double) set_sizes[j] * (double) (set_sizes[j] - 1); + if (denom != 0) { + D[j * n + j] /= denom; + } for (k = j + 1; k < n; k++) { + denom = (double) set_sizes[j] * (double) set_sizes[k]; + D[j * n + k] /= denom; D[k * n + j] = D[j * n + k]; } } @@ -7563,19 +7610,23 @@ fill_lower_triangle( } int -tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, - const tsk_id_t *samples_in, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result) +tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_sample_sets_in, + const tsk_size_t *sample_set_sizes_in, const tsk_id_t *sample_sets_in, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result) { int ret = 0; - const tsk_id_t *samples = self->samples; - tsk_size_t n = self->num_samples; + tsk_size_t N, total_samples; + const tsk_size_t *sample_set_sizes; + const tsk_id_t *sample_sets; + tsk_size_t *tmp_sample_set_sizes = NULL; const double default_windows[] = { 0, self->tables->sequence_length }; const tsk_size_t num_nodes = self->tables->nodes.num_rows; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); bool stat_node = !!(options & TSK_STAT_NODE); - tsk_id_t *sample_index_map = NULL; + tsk_id_t *sample_set_index_map + = tsk_malloc(num_nodes * sizeof(*sample_set_index_map)); + tsk_size_t j; if (stat_node) { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; @@ -7606,42 +7657,57 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, } } - if (samples_in != NULL) { - samples = samples_in; - n = num_samples; - ret = tsk_treeseq_check_node_bounds(self, n, samples); - if (ret != 0) { + /* If sample_sets is NULL, use self->samples and ignore input + * num_sample_sets */ + sample_sets = sample_sets_in; + N = num_sample_sets_in; + if (sample_sets_in == NULL) { + sample_sets = self->samples; + if (sample_set_sizes_in == NULL) { + N = self->num_samples; + } + } + sample_set_sizes = sample_set_sizes_in; + /* If sample_set_sizes is NULL, assume its N 1S */ + if (sample_set_sizes_in == NULL) { + tmp_sample_set_sizes = tsk_malloc(N * sizeof(*tmp_sample_set_sizes)); + if (tmp_sample_set_sizes == NULL) { + ret = TSK_ERR_NO_MEMORY; goto out; } + for (j = 0; j < N; j++) { + tmp_sample_set_sizes[j] = 1; + } + sample_set_sizes = tmp_sample_set_sizes; } - /* NOTE: we're just using this here to check the input for duplicates. - */ - ret = get_sample_index_map(num_nodes, n, samples, &sample_index_map); + ret = get_sample_set_index_map( + self, N, sample_set_sizes, sample_sets, &total_samples, sample_set_index_map); if (ret != 0) { goto out; } - tsk_memset(result, 0, num_windows * n * n * sizeof(*result)); + tsk_memset(result, 0, num_windows * N * N * sizeof(*result)); if (stat_branch) { - ret = tsk_treeseq_divergence_matrix_branch( - self, n, samples, num_windows, windows, options, result); + ret = tsk_treeseq_divergence_matrix_branch(self, N, sample_set_sizes, + sample_sets, num_windows, windows, options, result); } else { tsk_bug_assert(stat_site); - ret = tsk_treeseq_divergence_matrix_site( - self, n, samples, num_windows, windows, options, result); + ret = tsk_treeseq_divergence_matrix_site(self, N, sample_set_index_map, + total_samples, sample_sets, num_windows, windows, options, result); } if (ret != 0) { goto out; } - fill_lower_triangle(result, n, num_windows); + fill_lower_triangle_count_normalise(num_windows, N, sample_set_sizes, result); if (options & TSK_STAT_SPAN_NORMALISE) { - span_normalise(num_windows, windows, n * n, result); + span_normalise(num_windows, windows, N * N, result); } out: - tsk_safe_free(sample_index_map); + tsk_safe_free(sample_set_index_map); + tsk_safe_free(tmp_sample_set_sizes); return ret; } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 7bc1e60092..dbc870ad2f 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1098,9 +1098,9 @@ int tsk_treeseq_f4(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples, - const tsk_id_t *samples, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result); +int tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); /****************************************************************************/ /* Tree */ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index ea42f98b5b..b32967c6ff 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -802,11 +802,16 @@ parse_sample_sets(PyObject *sample_set_sizes, PyArrayObject **ret_sample_set_siz } shape = PyArray_DIMS(sample_set_sizes_array); num_sample_sets = shape[0]; + /* The sum of the lengths in sample_set_sizes must be equal to the length * of the sample_sets array */ sum = 0; a = PyArray_DATA(sample_set_sizes_array); for (j = 0; j < num_sample_sets; j++) { + if (sum + a[j] < sum) { + PyErr_SetString(PyExc_ValueError, "Overflow in sample set sizes sum"); + goto out; + } sum += a[j]; } @@ -9777,44 +9782,47 @@ static PyObject * TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) { PyObject *ret = NULL; - static char *kwlist[] = { "windows", "samples", "mode", "span_normalise", NULL }; - PyArrayObject *result_array = NULL; - PyObject *windows = NULL; - PyObject *py_samples = Py_None; + + static char *kwlist[] = { "windows", "sample_set_sizes", "sample_sets", "mode", + "span_normalise", NULL }; char *mode = NULL; + PyArrayObject *result_array = NULL; + PyObject *py_sample_set_sizes = Py_None; + PyObject *py_sample_sets = Py_None; + PyObject *py_windows = Py_None; PyArrayObject *windows_array = NULL; - PyArrayObject *samples_array = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; tsk_flags_t options = 0; - npy_intp *shape, dims[3]; - tsk_size_t num_samples, num_windows; - tsk_id_t *samples = NULL; + npy_intp dims[3]; + tsk_size_t num_sample_sets = 0; + tsk_size_t num_windows = 0; + tsk_id_t *sample_sets = NULL; + tsk_size_t *sample_set_sizes = NULL; int span_normalise = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|Osi", kwlist, &windows, &py_samples, - &mode, &span_normalise)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|si", kwlist, &py_windows, + &py_sample_set_sizes, &py_sample_sets, &mode, &span_normalise)) { goto out; } - num_samples = tsk_treeseq_get_num_samples(self->tree_sequence); - if (py_samples != Py_None) { - samples_array = (PyArrayObject *) PyArray_FROMANY( - py_samples, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); - if (samples_array == NULL) { - goto out; - } - shape = PyArray_DIMS(samples_array); - samples = PyArray_DATA(samples_array); - num_samples = (tsk_size_t) shape[0]; + + if (parse_sample_sets(py_sample_set_sizes, &sample_set_sizes_array, py_sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; } - if (parse_windows(windows, &windows_array, &num_windows) != 0) { + sample_set_sizes = PyArray_DATA(sample_set_sizes_array); + sample_sets = PyArray_DATA(sample_sets_array); + if (parse_windows(py_windows, &windows_array, &num_windows) != 0) { goto out; } dims[0] = num_windows; - dims[1] = num_samples; - dims[2] = num_samples; + dims[1] = num_sample_sets; + dims[2] = num_sample_sets; result_array = (PyArrayObject *) PyArray_SimpleNew(3, dims, NPY_FLOAT64); if (result_array == NULL) { goto out; @@ -9831,7 +9839,7 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd Py_BEGIN_ALLOW_THREADS err = tsk_treeseq_divergence_matrix( self->tree_sequence, - num_samples, samples, + num_sample_sets, sample_set_sizes, sample_sets, num_windows, PyArray_DATA(windows_array), options, PyArray_DATA(result_array)); Py_END_ALLOW_THREADS @@ -9845,9 +9853,10 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd ret = (PyObject *) result_array; result_array = NULL; out: - Py_XDECREF(result_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(sample_sets_array); Py_XDECREF(windows_array); - Py_XDECREF(samples_array); + Py_XDECREF(result_array); return ret; } diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index bf5118a5d4..ea83cc560d 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023 Tskit Developers +# Copyright (c) 2023-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,10 @@ """ Test cases for divergence matrix based pairwise stats """ +import array +import collections +import functools + import msprime import numpy as np import pytest @@ -161,15 +165,29 @@ def span_normalise_windows(D, windows): D[j] /= span -def branch_divergence_matrix(ts, windows=None, samples=None, span_normalise=True): +def sample_set_normalisation(sample_sets): + n = len(sample_sets) + C = np.zeros((n, n)) + for j in range(n): + C[j, j] = len(sample_sets[j]) * (len(sample_sets[j]) - 1) + for k in range(j + 1, n): + C[j, k] = len(sample_sets[j]) * len(sample_sets[k]) + C[k, j] = C[j, k] + # Avoid division by zero for singleton samplesets + C[C == 0] = 1 + # print("C = ", C) + return C + + +def branch_divergence_matrix(ts, sample_sets=None, windows=None, span_normalise=True): windows_specified = windows is not None windows = ts.parse_windows(windows) num_windows = len(windows) - 1 - samples = ts.samples() if samples is None else samples - n = len(samples) + n = len(sample_sets) D = np.zeros((num_windows, n, n)) tree = tskit.Tree(ts) + C = sample_set_normalisation(sample_sets) for i in range(num_windows): left = windows[i] right = windows[i + 1] @@ -183,23 +201,33 @@ def branch_divergence_matrix(ts, windows=None, samples=None, span_normalise=True # print(f"\ttree {tree.interval} [{span_left}, {span_right})") tables = sv_tables_init(tree.parent_array) for j in range(n): - u = samples[j] - for k in range(j + 1, n): - v = samples[k] - w = sv_mrca(u, v, *tables) - assert w == tree.mrca(u, v) - if w != tskit.NULL: - tu = ts.nodes_time[w] - ts.nodes_time[u] - tv = ts.nodes_time[w] - ts.nodes_time[v] - else: - tu = ts.nodes_time[local_root(tree, u)] - ts.nodes_time[u] - tv = ts.nodes_time[local_root(tree, v)] - ts.nodes_time[v] - d = (tu + tv) * span - D[i, j, k] += d + for u in sample_sets[j]: + for k in range(j, n): + for v in sample_sets[k]: + # The u=v case here contributes zero, not bothering + # to exclude it. + w = sv_mrca(u, v, *tables) + assert w == tree.mrca(u, v) + if w != tskit.NULL: + tu = ts.nodes_time[w] - ts.nodes_time[u] + tv = ts.nodes_time[w] - ts.nodes_time[v] + else: + tu = ( + ts.nodes_time[local_root(tree, u)] + - ts.nodes_time[u] + ) + tv = ( + ts.nodes_time[local_root(tree, v)] + - ts.nodes_time[v] + ) + d = (tu + tv) * span + D[i, j, k] += d tree.next() - # Fill out symmetric triangle in the matrix + # Fill out symmetric triangle in the matrix, and get average for j in range(n): + D[i, j, j] /= C[j, j] for k in range(j + 1, n): + D[i, j, k] /= C[j, k] D[i, k, j] = D[i, j, k] if span_normalise: span_normalise_windows(D, windows) @@ -208,27 +236,61 @@ def branch_divergence_matrix(ts, windows=None, samples=None, span_normalise=True return D -def divergence_matrix(ts, windows=None, samples=None, mode="site", span_normalise=True): +def divergence_matrix( + ts, windows=None, sample_sets=None, samples=None, mode="site", span_normalise=True +): assert mode in ["site", "branch"] + if samples is not None and sample_sets is not None: + raise ValueError("Cannot specify both") + if samples is None and sample_sets is None: + samples = ts.samples() + if samples is not None: + sample_sets = [[u] for u in samples] + else: + assert sample_sets is not None + if mode == "site": return site_divergence_matrix( - ts, samples=samples, windows=windows, span_normalise=span_normalise + ts, sample_sets, windows=windows, span_normalise=span_normalise ) else: return branch_divergence_matrix( - ts, samples=samples, windows=windows, span_normalise=span_normalise + ts, sample_sets, windows=windows, span_normalise=span_normalise ) -def stats_api_divergence_matrix( - ts, windows=None, samples=None, mode="site", span_normalise=True +def stats_api_divergence_matrix(ts, *args, **kwargs): + return stats_api_matrix_method(ts, ts.divergence, *args, **kwargs) + + +def stats_api_genetic_relatedness_matrix(ts, *args, **kwargs): + method = functools.partial(ts.genetic_relatedness, proportion=False) + return stats_api_matrix_method(ts, method, *args, **kwargs) + + +def stats_api_matrix_method( + ts, + method, + windows=None, + samples=None, + sample_sets=None, + mode="site", + span_normalise=True, ): - samples = ts.samples() if samples is None else samples + if samples is not None and sample_sets is not None: + raise ValueError("Cannot specify both") + if samples is None and sample_sets is None: + samples = ts.samples() + if samples is not None: + sample_sets = [[u] for u in samples] + else: + assert sample_sets is not None + windows_specified = windows is not None windows = [0, ts.sequence_length] if windows is None else list(windows) num_windows = len(windows) - 1 - if len(samples) == 0: + if len(sample_sets) == 0: # FIXME: the code general stat code doesn't seem to handle zero samples # case, need to identify MWE and file issue. if windows_specified: @@ -236,17 +298,6 @@ def stats_api_divergence_matrix( else: return np.zeros(shape=(0, 0)) - # Make sure that all the specified samples have the sample flag set, otherwise - # the library code will complain - tables = ts.dump_tables() - flags = tables.nodes.flags - # NOTE: this is a shortcut, setting all flags unconditionally to zero, so don't - # use this tree sequence outside this method. - flags[:] = 0 - flags[samples] = tskit.NODE_IS_SAMPLE - tables.nodes.flags = flags - ts = tables.tree_sequence() - # FIXME We have to go through this annoying rigmarole because windows must start and # end with 0 and L. We should relax this requirement to just making the windows # contiguous, so that we just look at specific sections of the genome. @@ -258,10 +309,9 @@ def stats_api_divergence_matrix( windows.append(ts.sequence_length) drop.append(-1) - n = len(samples) - sample_sets = [[u] for u in samples] + n = len(sample_sets) indexes = [(i, j) for i in range(n) for j in range(n)] - X = ts.divergence( + X = method( sample_sets, indexes=indexes, mode=mode, @@ -271,9 +321,9 @@ def stats_api_divergence_matrix( keep = np.ones(len(windows) - 1, dtype=bool) keep[drop] = False X = X[keep] + # Quick hack to get the within singleton sampleset divergence=0 + X[np.isnan(X)] = 0 out = X.reshape((X.shape[0], n, n)) - for D in out: - np.fill_diagonal(D, 0) if not windows_specified: out = out[0] return out @@ -294,16 +344,21 @@ def group_alleles(genotypes, num_alleles): return A, offsets -def site_divergence_matrix(ts, windows=None, samples=None, span_normalise=True): +def site_divergence_matrix(ts, sample_sets, *, windows=None, span_normalise=True): windows_specified = windows is not None windows = ts.parse_windows(windows) num_windows = len(windows) - 1 - samples = ts.samples() if samples is None else samples - n = len(samples) - sample_index_map = np.zeros(ts.num_nodes, dtype=int) - 1 - sample_index_map[samples] = np.arange(n) + n = len(sample_sets) + samples = [] + sample_set_index_map = [] + for j in range(n): + for u in sample_sets[j]: + samples.append(u) + sample_set_index_map.append(j) + C = sample_set_normalisation(sample_sets) D = np.zeros((num_windows, n, n)) + site_id = 0 while site_id < ts.num_sites and ts.sites_position[site_id] < windows[0]: site_id += 1 @@ -324,10 +379,13 @@ def site_divergence_matrix(ts, windows=None, samples=None, span_normalise=True): for k in range(j + 1, variant.num_alleles): B = X[offsets[k] : offsets[k + 1]] for a in A: + a_set_index = sample_set_index_map[a] for b in B: - D[i, a, b] += 1 - D[i, b, a] += 1 + b_set_index = sample_set_index_map[b] + D[i, a_set_index, b_set_index] += 1 + D[i, b_set_index, a_set_index] += 1 site_id += 1 + D[i] /= C if span_normalise: span_normalise_windows(D, windows) if not windows_specified: @@ -340,27 +398,32 @@ def check_divmat( *, windows=None, samples=None, + sample_sets=None, span_normalise=True, verbosity=0, compare_stats_api=True, compare_lib=True, mode="site", ): - np.set_printoptions(linewidth=500, precision=4) + # print("samples = ", samples, sample_sets) # print(ts.draw_text()) if verbosity > 1: print(ts.draw_text()) D1 = divergence_matrix( - ts, windows=windows, samples=samples, mode=mode, span_normalise=span_normalise + ts, + sample_sets=sample_sets, + samples=samples, + windows=windows, + mode=mode, + span_normalise=span_normalise, ) if compare_stats_api: - # Somethings like duplicate samples aren't worth hacking around for in - # stats API. D2 = stats_api_divergence_matrix( ts, windows=windows, samples=samples, + sample_sets=sample_sets, mode=mode, span_normalise=span_normalise, ) @@ -370,12 +433,24 @@ def check_divmat( np.testing.assert_allclose(D1, D2) assert D1.shape == D2.shape if compare_lib: + ids = None + if sample_sets is not None: + ids = sample_sets + if samples is not None: + ids = samples D3 = ts.divergence_matrix( - windows=windows, samples=samples, mode=mode, span_normalise=span_normalise + ids, + windows=windows, + mode=mode, + span_normalise=span_normalise, ) + # print() + # np.set_printoptions(linewidth=500, precision=4) + # print(D1) # print(D3) assert D1.shape == D3.shape np.testing.assert_allclose(D1, D3) + return D1 @@ -467,6 +542,25 @@ def test_single_tree_sequence_length_span_normalise(self, L): ) np.testing.assert_array_equal(D1, D2) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_diploid_individuals(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + ts = tsutil.insert_individuals(ts, ploidy=2) + D1 = check_divmat( + ts, + sample_sets=[ind.nodes for ind in ts.individuals()], + mode=mode, + ) + D2 = np.array([[2.0, 4.0], [4.0, 2.0]]) + np.testing.assert_array_equal(D1, D2) + @pytest.mark.parametrize("num_windows", [1, 2, 3, 5]) @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_gap_at_end(self, num_windows, mode): @@ -524,14 +618,8 @@ def test_single_tree_mixed_non_sample_samples(self, mode): # 0 1 ts = tskit.Tree.generate_balanced(4).tree_sequence ts = tsutil.insert_branch_sites(ts) - D1 = check_divmat(ts, samples=[0, 5], mode=mode) - D2 = np.array( - [ - [0.0, 3.0], - [3.0, 0.0], - ] - ) - np.testing.assert_array_equal(D1, D2) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_SAMPLES"): + ts.divergence_matrix([0, 5], mode=mode) @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_duplicate_samples(self, mode): @@ -544,7 +632,7 @@ def test_single_tree_duplicate_samples(self, mode): ts = tskit.Tree.generate_balanced(4).tree_sequence ts = tsutil.insert_branch_sites(ts) with pytest.raises(tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): - ts.divergence_matrix(samples=[0, 0, 1], mode=mode) + ts.divergence_matrix([0, 0, 1], mode=mode) @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_single_tree_multiroot(self, mode): @@ -831,14 +919,14 @@ def check( self, ts, windows=None, - samples=None, + sample_sets=None, num_threads=0, span_normalise=True, mode="branch", ): D1 = ts.divergence_matrix( + sample_sets, windows=windows, - samples=samples, num_threads=num_threads, mode=mode, span_normalise=span_normalise, @@ -846,11 +934,12 @@ def check( D2 = stats_api_divergence_matrix( ts, windows=windows, - samples=samples, + sample_sets=sample_sets, mode=mode, span_normalise=span_normalise, ) assert D1.shape == D2.shape + # np.set_printoptions(linewidth=500, precision=4) # print() # print(D1) # print(D2) @@ -866,7 +955,7 @@ def check( np.testing.assert_allclose(D1, D2, atol=atol) else: assert mode == "site" - np.testing.assert_array_equal(D1, D2) + np.testing.assert_allclose(D1, D2) @pytest.mark.parametrize("ts", get_example_tree_sequences()) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -877,7 +966,16 @@ def test_defaults(self, ts, mode): @pytest.mark.parametrize("mode", DIVMAT_MODES) def test_subset_samples(self, ts, mode): n = min(ts.num_samples, 2) - self.check(ts, samples=ts.samples()[:n], mode=mode) + self.check(ts, sample_sets=[[u] for u in ts.samples()[:n]], mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + @pytest.mark.parametrize("ploidy", [1, 2, 3]) + def test_ploidy_sample_sets(self, ts, mode, ploidy): + if ts.num_samples >= 2 * ploidy: + # Workaround limitations in the stats API + sample_sets = np.array_split(ts.samples(), ts.num_samples // ploidy) + self.check(ts, sample_sets=sample_sets, mode=mode) @pytest.mark.parametrize("ts", get_example_tree_sequences()) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -899,17 +997,25 @@ def test_threads_windows(self, ts, mode): class TestThreadsNoWindows: - def check(self, ts, num_threads, samples=None, mode=None): - D1 = ts.divergence_matrix(num_threads=0, samples=samples, mode=mode) - D2 = ts.divergence_matrix(num_threads=num_threads, samples=samples, mode=mode) + def check(self, ts, num_threads, samples=None, mode=None, span_normalise=True): + D1 = ts.divergence_matrix( + samples, num_threads=0, mode=mode, span_normalise=span_normalise + ) + D2 = ts.divergence_matrix( + samples, + num_threads=num_threads, + mode=mode, + span_normalise=span_normalise, + ) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) @pytest.mark.parametrize("mode", DIVMAT_MODES) - def test_all_trees(self, num_threads, mode): + @pytest.mark.parametrize("span_normalise", [True, False]) + def test_all_trees(self, num_threads, mode, span_normalise): ts = tsutil.all_trees_ts(4) assert ts.num_trees == 26 - self.check(ts, num_threads, mode=mode) + self.check(ts, num_threads, mode=mode, span_normalise=span_normalise) @pytest.mark.parametrize("samples", [None, [0, 1]]) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -936,11 +1042,9 @@ def test_simple_sims(self, n, num_threads, mode): class TestThreadsWindows: def check(self, ts, num_threads, *, windows, samples=None, mode=None): - D1 = ts.divergence_matrix( - num_threads=0, windows=windows, samples=samples, mode=mode - ) + D1 = ts.divergence_matrix(samples, num_threads=0, windows=windows, mode=mode) D2 = ts.divergence_matrix( - num_threads=num_threads, windows=windows, samples=samples, mode=mode + samples, num_threads=num_threads, windows=windows, mode=mode ) np.testing.assert_array_almost_equal(D1, D2) @@ -1155,3 +1259,177 @@ def test_simple_simulation(self): for j in range(var.num_alleles): a = A[offsets[j] : offsets[j + 1]] assert list(a) == list(allele_samples[j]) + + +class TestSampleSetParsing: + @pytest.mark.parametrize( + ["arg", "flattened", "sizes"], + [ + ([], [], []), + ([1], [1], [1]), + ([1, 2], [1, 2], [1, 1]), + ([[1, 2], [3, 4]], [1, 2, 3, 4], [2, 2]), + (((1, 2), (3, 4)), [1, 2, 3, 4], [2, 2]), + (np.array([[1, 2], [3, 4]]), [1, 2, 3, 4], [2, 2]), + (np.array([1, 2]), [1, 2], [1, 1]), + (np.array([1, 2], dtype=np.uint32), [1, 2], [1, 1]), + (array.array("i", [1, 2]), [1, 2], [1, 1]), + ([[1, 2], [3], [4]], [1, 2, 3, 4], [2, 1, 1]), + ([[1], [2]], [1, 2], [1, 1]), + ([[1, 1], [2]], [1, 1, 2], [2, 1]), + ], + ) + def test_good_args(self, arg, flattened, sizes): + f, s = tskit.TreeSequence._parse_stat_matrix_sample_sets(arg) + # print(f, s) + assert isinstance(f, np.ndarray) + assert f.dtype == np.int32 + assert isinstance(s, np.ndarray) + assert s.dtype == np.uint64 + np.testing.assert_array_equal(f, flattened) + np.testing.assert_array_equal(s, sizes) + + @pytest.mark.parametrize( + "arg", + [ + ["0", "1"], + ["0", 1], + [0, "1"], + [0, {"a": "b"}], + ], + ) + def test_nested_bad_types(self, arg): + with pytest.raises(TypeError): + tskit.TreeSequence._parse_stat_matrix_sample_sets(arg) + + @pytest.mark.parametrize( + "arg", + [ + [[0], [[0, 0]]], + [[[0, 0]], [0]], + np.array([[[0, 0], [0, 0]]]), + ], + ) + def test_nested_arrays(self, arg): + with pytest.raises(ValueError): + tskit.TreeSequence._parse_stat_matrix_sample_sets(arg) + + @pytest.mark.parametrize("arg", ["", "string", "1", "[1, 2]", b"", "1234"]) + def test_string_args(self, arg): + with pytest.raises(TypeError, match="ID specification cannot be"): + tskit.TreeSequence._parse_stat_matrix_sample_sets(arg) + + @pytest.mark.parametrize( + "arg", + [ + {}, + {"a": "b"}, + collections.Counter(), + ], + ) + def test_dict_args(self, arg): + with pytest.raises(TypeError, match="ID specification cannot be"): + tskit.TreeSequence._parse_stat_matrix_sample_sets(arg) + + @pytest.mark.parametrize( + "arg", + [ + 0, + {0: 1}, + None, + {"a": "b"}, + np.array([1.1]), + ], + ) + def test_bad_arg_types(self, arg): + with pytest.raises(TypeError): + tskit.TreeSequence._parse_stat_matrix_sample_sets(arg) + + +class TestGeneticRelatednessMatrix: + def check(self, ts, mode, *, sample_sets=None, windows=None, span_normalise=True): + G1 = stats_api_genetic_relatedness_matrix( + ts, + mode=mode, + sample_sets=sample_sets, + windows=windows, + span_normalise=span_normalise, + ) + G2 = ts.genetic_relatedness_matrix( + mode=mode, + sample_sets=sample_sets, + windows=windows, + span_normalise=span_normalise, + ) + np.testing.assert_array_almost_equal(G1, G2) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + self.check(ts, mode) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_sample_sets(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + with pytest.raises(ValueError, match="2888"): + self.check(ts, mode, sample_sets=[[0, 1], [2, 3]]) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_single_samples(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + self.check(ts, mode, sample_sets=[[0], [1]]) + self.check(ts, mode, sample_sets=[[0], [2]]) + self.check(ts, mode, sample_sets=[[0], [1], [2]]) + + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_single_tree_windows(self, mode): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4).tree_sequence + ts = tsutil.insert_branch_sites(ts) + self.check(ts, mode, windows=[0, 0.5, 1]) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + def test_suite_defaults(self, ts, mode): + self.check(ts, mode=mode) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + @pytest.mark.parametrize("span_normalise", [True, False]) + def test_suite_span_normalise(self, ts, mode, span_normalise): + self.check(ts, mode=mode, span_normalise=span_normalise) + + @pytest.mark.skip("fix sample sets #2888") + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("mode", DIVMAT_MODES) + @pytest.mark.parametrize("num_sets", [2]) # [[2, 3, 4, 5]) + def test_suite_sample_sets(self, ts, mode, num_sets): + if ts.num_samples >= num_sets: + sample_sets = np.array_split(ts.samples(), num_sets) + self.check(ts, sample_sets=sample_sets, mode=mode) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 16676a240c..23f240006c 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1549,26 +1549,41 @@ def test_divergence_matrix(self): n = 10 ts = self.get_example_tree_sequence(n, random_seed=12) windows = [0, ts.get_sequence_length()] - D = ts.divergence_matrix(windows) + ids = np.arange(n, dtype=np.int32) + sizes = np.ones(n, dtype=np.uint64) + D = ts.divergence_matrix(windows, sizes, ids) assert D.shape == (1, n, n) - D = ts.divergence_matrix(windows, samples=[0, 1]) + D = ts.divergence_matrix(windows, sample_set_sizes=[1, 1], sample_sets=[0, 1]) assert D.shape == (1, 2, 2) - D = ts.divergence_matrix(windows, samples=[0, 1], span_normalise=True) + D = ts.divergence_matrix( + windows, sample_set_sizes=[1, 1], sample_sets=[0, 1], span_normalise=True + ) assert D.shape == (1, 2, 2) + + for bad_node in [-1, -2, 1000]: + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + ts.divergence_matrix(windows, [1, 1], [0, bad_node]) + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + ts.divergence_matrix(windows, [1, 2], [0, 1]) + with pytest.raises(ValueError, match="Overflow"): + ts.divergence_matrix(windows, [-1, 2], [0]) + with pytest.raises(TypeError, match="str"): - ts.divergence_matrix(windows, span_normalise="xdf") + ts.divergence_matrix(windows, sizes, ids, span_normalise="xdf") with pytest.raises(TypeError): ts.divergence_matrix(windoze=[0, 1]) with pytest.raises(ValueError, match="at least 2"): - ts.divergence_matrix(windows=[0]) + ts.divergence_matrix( + [0], + sizes, + ids, + ) with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): - ts.divergence_matrix(windows=[-1, 0, 1]) - with pytest.raises(ValueError): - ts.divergence_matrix(windows=[0, 1], samples="sdf") + ts.divergence_matrix([-1, 0, 1], sizes, ids) with pytest.raises(ValueError, match="Unrecognised stats mode"): - ts.divergence_matrix(windows=[0, 1], mode="sdf") + ts.divergence_matrix([0, 1], sizes, ids, mode="sdf") with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): - ts.divergence_matrix(windows=[0, 1], mode="node") + ts.divergence_matrix([0, 1], sizes, ids, mode="node") def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 1db0333fbb..0a59824788 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -7876,6 +7876,38 @@ def worker(sub_windows): concurrent.futures.wait(futures) return np.vstack([future.result() for future in futures]) + @staticmethod + def _parse_stat_matrix_sample_sets(ids): + """ + Returns a flattened list of sets of IDs. If ids is a 1D list, + interpret as n one-element sets. Otherwise, it must be a sequence + of ID lists. + """ + id_dtype = np.int32 + size_dtype = np.uint64 + # Exclude some types that could be specified accidentally, and + # we may want to reserve for future use. + if isinstance(ids, (str, bytes, collections.abc.Mapping, numbers.Number)): + raise TypeError(f"ID specification cannot be a {type(ids)}") + if len(ids) == 0: + return np.array([], dtype=id_dtype), np.array([], dtype=size_dtype) + if isinstance(ids[0], numbers.Number): + # Interpret as a 1D array + flat = util.safe_np_int_cast(ids, id_dtype) + sizes = np.ones(len(flat), dtype=size_dtype) + else: + set_lists = [] + sizes = [] + for id_list in ids: + a = util.safe_np_int_cast(id_list, id_dtype) + if len(a.shape) != 1: + raise ValueError("ID sets must be 1D integer arrays") + set_lists.append(a) + sizes.append(len(a)) + flat = np.hstack(set_lists) + sizes = np.array(sizes, dtype=size_dtype) + return flat, sizes + # def divergence_matrix(self, sample_sets, windows=None, mode="site"): # """ # Finds the mean divergence between pairs of samples from each set of @@ -7912,6 +7944,11 @@ def worker(sub_windows): # NOTE: see older definition of divmat here, which may be useful when documenting # this function. See https://github.com/tskit-dev/tskit/issues/2781 + # NOTE for documentation of sample_sets. We *must* use samples currently because + # the normalisation for non-sample nodes is tricky. Do we normalise by the + # total span of the ts where the node is 'present' in the tree? We avoid this + # by insisting on sample nodes. + # NOTE for documentation of num_threads. Need to explain that the # its best to think of as the number of background *worker* threads. # default is to run without any worker threads. If you want to run @@ -7919,9 +7956,9 @@ def worker(sub_windows): def divergence_matrix( self, + sample_sets=None, *, windows=None, - samples=None, num_threads=0, mode=None, span_normalise=True, @@ -7930,12 +7967,22 @@ def divergence_matrix( windows = self.parse_windows(windows) mode = "site" if mode is None else mode + if sample_sets is None: + sample_sets = self.samples() + flattened_samples = self.samples() + sample_set_sizes = np.ones(len(sample_sets), dtype=np.uint32) + else: + flattened_samples, sample_set_sizes = self._parse_stat_matrix_sample_sets( + sample_sets + ) + # FIXME this logic should be merged into __run_windowed_stat if # we generalise the num_threads argument to all stats. if num_threads <= 0: D = self._ll_tree_sequence.divergence_matrix( windows, - samples=samples, + sample_sets=flattened_samples, + sample_set_sizes=sample_set_sizes, mode=mode, span_normalise=span_normalise, ) @@ -7944,7 +7991,8 @@ def divergence_matrix( D = self._parallelise_divmat_by_window( windows, num_threads, - samples=samples, + sample_sets=flattened_samples, + sample_set_sizes=sample_set_sizes, mode=mode, span_normalise=span_normalise, ) @@ -7952,7 +8000,8 @@ def divergence_matrix( D = self._parallelise_divmat_by_tree( num_threads, span_normalise=span_normalise, - samples=samples, + sample_sets=flattened_samples, + sample_set_sizes=sample_set_sizes, mode=mode, ) @@ -8073,6 +8122,50 @@ def genetic_relatedness( return out + def genetic_relatedness_matrix( + self, + sample_sets=None, + *, + windows=None, + num_threads=0, + mode=None, + span_normalise=True, + ): + D = self.divergence_matrix( + sample_sets, + windows=windows, + num_threads=num_threads, + mode=mode, + span_normalise=span_normalise, + ) + + # FIXME remove this when sample sets bug has been fixed. + # https://github.com/tskit-dev/tskit/issues/2888 + if sample_sets is not None: + if any(len(ss) > 1 for ss in sample_sets): + raise ValueError( + "Only single entry sample sets allowed for now." + " See https://github.com/tskit-dev/tskit/issues/2888" + ) + + def _normalise(B): + if len(B) == 0: + return B + K = B + np.mean(B) + y = np.mean(B, axis=0) + X = y[:, np.newaxis] + y[np.newaxis, :] + K -= X + # FIXME this factor of 2 works for single-sample sample-sets, but not + # otherwise. https://github.com/tskit-dev/tskit/issues/2888 + return K / -2 + + if windows is None: + return _normalise(D) + else: + for j in range(D.shape[0]): + D[j] = _normalise(D[j]) + return D + def genetic_relatedness_weighted( self, W,