Skip to content

Commit 4abc7f7

Browse files
Add coefficient_dtype property to Basis to enforce input and output types for evaluate and evaluate_t (#805)
* rewrite simplified coefficient_dtype * add in for tests * set coefficient_dtype in Basis __init__ * spacing
1 parent 8417bca commit 4abc7f7

File tree

5 files changed

+34
-29
lines changed

5 files changed

+34
-29
lines changed

src/aspire/basis/basis.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
4848
raise NotImplementedError(
4949
"Currently only implemented for float32 and float64 types"
5050
)
51+
# dtype of coefficients is the same as self.dtype for real bases
52+
# subclasses with complex coefficients override this attribute
53+
self.coefficient_dtype = self.dtype
5154

5255
self._build()
5356

@@ -86,10 +89,10 @@ def evaluate(self, v):
8689
This is an Image or a Volume object containing one image/volume for each
8790
coefficient vector, and of size `self.sz`.
8891
"""
89-
if v.dtype != self.dtype:
92+
if v.dtype != self.coefficient_dtype:
9093
logger.warning(
9194
f"{self.__class__.__name__}::evaluate"
92-
f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}"
95+
f" Inconsistent dtypes v: {v.dtype} self coefficient dtype: {self.coefficient_dtype}"
9396
)
9497

9598
# Flatten stack, ndim is wrt Basis (2 or 3)
@@ -190,6 +193,12 @@ def expand(self, x):
190193
if isinstance(x, Image) or isinstance(x, Volume):
191194
x = x.asnumpy()
192195

196+
if x.dtype != self.dtype:
197+
logger.warning(
198+
f"{self.__class__.__name__}::expand"
199+
f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}"
200+
)
201+
193202
# check that last ndim values of input shape match
194203
# the shape of this basis
195204
assert (
@@ -212,7 +221,7 @@ def expand(self, x):
212221

213222
# number of image samples
214223
n_data = x.shape[0]
215-
v = np.zeros((n_data, self.count), dtype=x.dtype)
224+
v = np.zeros((n_data, self.count), dtype=self.coefficient_dtype)
216225

217226
for isample in range(0, n_data):
218227
b = self.evaluate_t(self._cls(x[isample])).T

src/aspire/basis/fpswf_2d.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from aspire.basis.basis_utils import lgwt, t_x_mat, t_x_mat_dot
99
from aspire.basis.pswf_2d import PSWFBasis2D
10-
from aspire.image import Image
1110
from aspire.nufft import nufft
1211
from aspire.numeric import fft, xp
1312
from aspire.utils import complex_type
@@ -107,21 +106,13 @@ def _precomp(self):
107106
self.n_max = n_max
108107
self.size_x = len(self._disk_mask)
109108

110-
def evaluate_t(self, images):
109+
def _evaluate_t(self, images):
111110
"""
112111
Evaluate coefficient vectors in PSWF basis using the fast method.
113112
114113
:param images: Image stack in the standard 2D coordinate basis.
115114
:return: Coefficient array in the PSWF basis.
116115
"""
117-
118-
if not isinstance(images, Image):
119-
logger.warning(
120-
"FPSWFBasis2D.evaluate_t expects Image instance,"
121-
" attempting conversion."
122-
)
123-
images = Image(images)
124-
125116
# Construct array with zeros outside mask
126117
images_disk = np.zeros(images.shape, dtype=images.dtype)
127118
images_disk[:, self._disk_mask] = images[:, self._disk_mask]

src/aspire/basis/polar_2d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from aspire.basis import Basis
66
from aspire.nufft import anufft, nufft
7+
from aspire.utils import complex_type
78

89
logger = logging.getLogger(__name__)
910

@@ -34,6 +35,9 @@ def __init__(self, size, nrad=None, ntheta=None, dtype=np.float32):
3435

3536
super().__init__(size, dtype=dtype)
3637

38+
# this basis has complex coefficients
39+
self.coefficient_dtype = complex_type(self.dtype)
40+
3741
def _build(self):
3842
"""
3943
Build the internal data structure to 2D polar Fourier grid

src/aspire/basis/pswf_2d.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
t_x_mat,
1313
)
1414
from aspire.basis.pswf_utils import BNMatrix
15-
from aspire.image import Image
1615
from aspire.utils import complex_type
1716

1817
logger = logging.getLogger(__name__)
@@ -59,6 +58,9 @@ def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32):
5958
self.beta = beta
6059
super().__init__(size, dtype=dtype)
6160

61+
# this basis has complex coefficients
62+
self.coefficient_dtype = complex_type(self.dtype)
63+
6264
def _build(self):
6365
"""
6466
Build internal data structures for the direct 2D PSWF method
@@ -149,27 +151,19 @@ def _generate_samples(self):
149151
# the column dimension of samples_conj_transpose is the number of basis coefficients
150152
self.count = self.samples_conj_transpose.shape[1]
151153

152-
def evaluate_t(self, images):
154+
def _evaluate_t(self, images):
153155
"""
154156
Evaluate coefficient vectors in PSWF basis using the direct method
155157
156158
:param images: coefficient array in the standard 2D coordinate basis
157159
to be evaluated.
158160
:return: The evaluation of the coefficient array in the PSWF basis.
159161
"""
160-
161-
if not isinstance(images, Image):
162-
logger.warning(
163-
"FPSWFBasis2D.evaluate_t expects Image instance,"
164-
" attempting conversion."
165-
)
166-
images = Image(images)
167-
168162
flattened_images = images[:, self._disk_mask]
169163

170164
return flattened_images @ self.samples_conj_transpose
171165

172-
def evaluate(self, coefficients):
166+
def _evaluate(self, coefficients):
173167
"""
174168
Evaluate coefficients in standard 2D coordinate basis from those in PSWF basis
175169
@@ -196,7 +190,7 @@ def evaluate(self, coefficients):
196190
)
197191
images[:, self._disk_mask] = np.real(flatten_images)
198192

199-
return Image(images)
193+
return images
200194

201195
def _init_pswf_func2d(self, c, eps):
202196
"""

tests/_basis_util.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,26 +138,33 @@ def getClass(self):
138138
return Volume
139139

140140
def testEvaluate(self):
141-
# evaluate should take a NumPy array and return an Image/Volume
141+
# evaluate should take a NumPy array of type basis.coefficient_dtype
142+
# and return an Image/Volume
142143
_class = self.getClass()
143-
result = self.basis.evaluate(np.zeros((self.basis.count), dtype=self.dtype))
144+
result = self.basis.evaluate(
145+
np.zeros((self.basis.count), dtype=self.basis.coefficient_dtype)
146+
)
144147
self.assertTrue(isinstance(result, _class))
145148

146149
def testEvaluate_t(self):
147-
# evaluate_t should take an Image/Volume and return a NumPy array
150+
# evaluate_t should take an Image/Volume and return a NumPy array of type
151+
# basis.coefficient_dtype
148152
_class = self.getClass()
149153
result = self.basis.evaluate_t(
150154
_class(np.zeros((self.L,) * self.basis.ndim, dtype=self.dtype))
151155
)
152156
self.assertTrue(isinstance(result, np.ndarray))
157+
self.assertEqual(result.dtype, self.basis.coefficient_dtype)
153158

154159
def testExpand(self):
155160
_class = self.getClass()
156-
# expand should take an Image/Volume and return a NumPy array
161+
# expand should take an Image/Volume and return a NumPy array of type
162+
# basis.coefficient_dtype
157163
result = self.basis.expand(
158164
_class(np.zeros((self.L,) * self.basis.ndim, dtype=self.dtype))
159165
)
160166
self.assertTrue(isinstance(result, np.ndarray))
167+
self.assertEqual(result.dtype, self.basis.coefficient_dtype)
161168

162169
def testInitWithIntSize(self):
163170
# make sure we can instantiate with just an int as a shortcut

0 commit comments

Comments
 (0)