Commit 4ef0cbb

1 parent dc03424 commit 4ef0cbb

3 files changed: +49 -12 lines changed


setup.cfg

Lines changed: 1 addition & 3 deletions
@@ -101,11 +101,9 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-dask.*]
 ignore_missing_imports = True
-<<<<<<< HEAD
 [mypy-fsspec.*]
-=======
+ignore_missing_imports = True
 [mypy-dask_ml.*]
->>>>>>> PCA implementation #95
 ignore_missing_imports = True
 [mypy-numpy.*]
 ignore_missing_imports = True

sgkit/stats/pca.py

Lines changed: 13 additions & 9 deletions
@@ -9,6 +9,8 @@
 from typing_extensions import Literal
 from xarray import DataArray, Dataset
 
+from sgkit import variables
+
 from ..typing import ArrayLike, DType, RandomStateType
 from ..utils import conditional_merge_datasets
 from .aggregation import count_call_alleles
@@ -93,8 +95,10 @@ def pca_transform(
     """ Apply PCA estimator to new data """
     AC = _allele_counts(ds, variable, check_missing=check_missing)
     projection = est.transform(da.asarray(AC).T)
-    new_ds = Dataset({"sample_pca_projection": (("samples", "components"), projection)})
-    return conditional_merge_datasets(ds, new_ds, merge)
+    new_ds = Dataset(
+        {variables.sample_pca_projection: (("samples", "components"), projection)}
+    )
+    return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
 
 
 def _get(est: BaseEstimator, attr: str, fn: Any = lambda v: v) -> Optional[ArrayLike]:
@@ -109,25 +113,25 @@ def _get(est: BaseEstimator, attr: str, fn: Any = lambda v: v) -> Optional[Array
 def pca_stats(ds: Dataset, est: BaseEstimator, *, merge: bool = True) -> Dataset:
     """ Extract attributes from PCA estimator """
     new_ds = {
-        "sample_pca_component": (
+        variables.sample_pca_component: (
             ("variants", "components"),
             _get(est, "components_", fn=lambda v: v.T),
         ),
-        "sample_pca_explained_variance": (
+        variables.sample_pca_explained_variance: (
             "components",
             _get(est, "explained_variance_"),
         ),
-        "sample_pca_explained_variance_ratio": (
+        variables.sample_pca_explained_variance_ratio: (
             "components",
             _get(est, "explained_variance_ratio_"),
         ),
     }
     new_ds = Dataset({k: v for k, v in new_ds.items() if v[1] is not None})
     if "sample_pca_component" in new_ds and "sample_pca_explained_variance" in new_ds:
-        new_ds["sample_pca_loading"] = new_ds["sample_pca_component"] * np.sqrt(
-            new_ds["sample_pca_explained_variance"]
-        )
-    return conditional_merge_datasets(ds, new_ds, merge)
+        new_ds[variables.sample_pca_loading] = new_ds[
+            variables.sample_pca_component
+        ] * np.sqrt(new_ds[variables.sample_pca_explained_variance])
+    return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
 
 
 def pca(
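
For reference, the loading computation in pca_stats above is simply the principal axes scaled by the square root of the explained variance. The following minimal NumPy sketch uses hypothetical shapes and stand-in arrays (not a fitted estimator, and not a call into sgkit) to mirror that arithmetic:

import numpy as np

rng = np.random.default_rng(42)
n_variants, n_components = 100, 4  # hypothetical sizes

# Stand-ins for the estimator attributes used in pca_stats:
# components_ has shape (components, variants), so it is transposed first.
components = rng.normal(size=(n_components, n_variants)).T  # (variants, components)
explained_variance = np.array([4.0, 2.0, 1.0, 0.5])  # (components,)

# Same relationship as in the diff: loading = component * sqrt(explained_variance)
loadings = components * np.sqrt(explained_variance)
print(loadings.shape)  # (100, 4)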

sgkit/variables.py

Lines changed: 35 additions & 0 deletions
@@ -247,6 +247,41 @@ def _check_field(
     ArrayLikeSpec("sample_pcs", ndim=2, kind="f")
 )
 """Sample PCs (PCxS)."""
+sample_pca_component, sample_pca_component_spec = SgkitVariables.register_variable(
+    ArrayLikeSpec("sample_pca_component", ndim=2, kind="f")
+)
+"""Principal axes defined as eigenvectors for sample covariance matrix.
+In the context of SVD, these are equivalent to the right singular vectors in
+the decomposition of a (N, M) matrix, i.e. ``dask_ml.decomposition.TruncatedSVD.components_``."""
+(
+    sample_pca_explained_variance,
+    sample_pca_explained_variance_spec,
+) = SgkitVariables.register_variable(
+    ArrayLikeSpec("sample_pca_explained_variance", ndim=1, kind="f")
+)
+"""Variance explained by each principal component. These values are equivalent
+to eigenvalues that result from the eigendecomposition of a (N, M) matrix,
+i.e. ``dask_ml.decomposition.TruncatedSVD.explained_variance_``."""
+(
+    sample_pca_explained_variance_ratio,
+    sample_pca_explained_variance_ratio_spec,
+) = SgkitVariables.register_variable(
+    ArrayLikeSpec("sample_pca_explained_variance_ratio", ndim=1, kind="f")
+)
+"""Ratio of variance explained to total variance for each principal component,
+i.e. ``dask_ml.decomposition.TruncatedSVD.explained_variance_ratio_``."""
+sample_pca_loading, sample_pca_loading_spec = SgkitVariables.register_variable(
+    ArrayLikeSpec("sample_pca_loading", ndim=2, kind="f")
+)
+"""PCA loadings defined as principal axes scaled by square root of eigenvalues.
+These values can also be interpreted as the correlation between the original variables
+and unit-scaled principal axes."""
+sample_pca_projection, sample_pca_projection_spec = SgkitVariables.register_variable(
+    ArrayLikeSpec("sample_pca_projection", ndim=2, kind="f")
+)
+"""Projection of samples onto principal axes. This array is commonly
+referred to as "scores" or simply "principal components (PCs)" for a set of samples."""
+
 stat_Fst, stat_Fst_spec = SgkitVariables.register_variable(
     ArrayLikeSpec("stat_Fst", ndim=2, kind="f")
 )
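
To make the registered specs concrete, here is a minimal, illustrative sketch (not sgkit's actual output) of an xarray Dataset whose arrays would satisfy them: 2-D float arrays over (variants, components) or (samples, components), and 1-D float arrays over components. The dimension sizes are hypothetical.

import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_variants, n_samples, n_components = 1000, 50, 10  # hypothetical sizes

mock = xr.Dataset(
    {
        # 2-D float arrays, matching the ndim=2, kind="f" specs
        "sample_pca_component": (
            ("variants", "components"),
            rng.normal(size=(n_variants, n_components)),
        ),
        "sample_pca_projection": (
            ("samples", "components"),
            rng.normal(size=(n_samples, n_components)),
        ),
        # 1-D float arrays, matching the ndim=1, kind="f" specs
        "sample_pca_explained_variance": ("components", rng.uniform(size=n_components)),
        "sample_pca_explained_variance_ratio": ("components", rng.uniform(size=n_components)),
    }
)
print(mock["sample_pca_component"].dims)  # ('variants', 'components')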
