9
9
from typing_extensions import Literal
10
10
from xarray import DataArray , Dataset
11
11
12
+ from sgkit import variables
13
+
12
14
from ..typing import ArrayLike , DType , RandomStateType
13
15
from ..utils import conditional_merge_datasets
14
16
from .aggregation import count_call_alleles
@@ -93,8 +95,10 @@ def pca_transform(
93
95
""" Apply PCA estimator to new data """
94
96
AC = _allele_counts (ds , variable , check_missing = check_missing )
95
97
projection = est .transform (da .asarray (AC ).T )
96
- new_ds = Dataset ({"sample_pca_projection" : (("samples" , "components" ), projection )})
97
- return conditional_merge_datasets (ds , new_ds , merge )
98
+ new_ds = Dataset (
99
+ {variables .sample_pca_projection : (("samples" , "components" ), projection )}
100
+ )
101
+ return conditional_merge_datasets (ds , variables .validate (new_ds ), merge )
98
102
99
103
100
104
def _get (est : BaseEstimator , attr : str , fn : Any = lambda v : v ) -> Optional [ArrayLike ]:
@@ -109,25 +113,25 @@ def _get(est: BaseEstimator, attr: str, fn: Any = lambda v: v) -> Optional[Array
109
113
def pca_stats (ds : Dataset , est : BaseEstimator , * , merge : bool = True ) -> Dataset :
110
114
""" Extract attributes from PCA estimator """
111
115
new_ds = {
112
- " sample_pca_component" : (
116
+ variables . sample_pca_component : (
113
117
("variants" , "components" ),
114
118
_get (est , "components_" , fn = lambda v : v .T ),
115
119
),
116
- " sample_pca_explained_variance" : (
120
+ variables . sample_pca_explained_variance : (
117
121
"components" ,
118
122
_get (est , "explained_variance_" ),
119
123
),
120
- " sample_pca_explained_variance_ratio" : (
124
+ variables . sample_pca_explained_variance_ratio : (
121
125
"components" ,
122
126
_get (est , "explained_variance_ratio_" ),
123
127
),
124
128
}
125
129
new_ds = Dataset ({k : v for k , v in new_ds .items () if v [1 ] is not None })
126
130
if "sample_pca_component" in new_ds and "sample_pca_explained_variance" in new_ds :
127
- new_ds [" sample_pca_loading" ] = new_ds ["sample_pca_component" ] * np . sqrt (
128
- new_ds [ "sample_pca_explained_variance" ]
129
- )
130
- return conditional_merge_datasets (ds , new_ds , merge )
131
+ new_ds [variables . sample_pca_loading ] = new_ds [
132
+ variables . sample_pca_component
133
+ ] * np . sqrt ( new_ds [ variables . sample_pca_explained_variance ] )
134
+ return conditional_merge_datasets (ds , variables . validate ( new_ds ) , merge )
131
135
132
136
133
137
def pca (
0 commit comments