@@ -48,6 +48,9 @@ def __init__(self, size, ell_max=None, dtype=np.float32):
48
48
raise NotImplementedError (
49
49
"Currently only implemented for float32 and float64 types"
50
50
)
51
+ # dtype of coefficients is the same as self.dtype for real bases
52
+ # subclasses with complex coefficients override this attribute
53
+ self .coefficient_dtype = self .dtype
51
54
52
55
self ._build ()
53
56
@@ -86,10 +89,10 @@ def evaluate(self, v):
86
89
This is an Image or a Volume object containing one image/volume for each
87
90
coefficient vector, and of size `self.sz`.
88
91
"""
89
- if v .dtype != self .dtype :
92
+ if v .dtype != self .coefficient_dtype :
90
93
logger .warning (
91
94
f"{ self .__class__ .__name__ } ::evaluate"
92
- f" Inconsistent dtypes v: { v .dtype } self: { self .dtype } "
95
+ f" Inconsistent dtypes v: { v .dtype } self coefficient dtype : { self .coefficient_dtype } "
93
96
)
94
97
95
98
# Flatten stack, ndim is wrt Basis (2 or 3)
@@ -190,6 +193,12 @@ def expand(self, x):
190
193
if isinstance (x , Image ) or isinstance (x , Volume ):
191
194
x = x .asnumpy ()
192
195
196
+ if x .dtype != self .dtype :
197
+ logger .warning (
198
+ f"{ self .__class__ .__name__ } ::expand"
199
+ f" Inconsistent dtypes x: { x .dtype } self: { self .dtype } "
200
+ )
201
+
193
202
# check that last ndim values of input shape match
194
203
# the shape of this basis
195
204
assert (
@@ -212,7 +221,7 @@ def expand(self, x):
212
221
213
222
# number of image samples
214
223
n_data = x .shape [0 ]
215
- v = np .zeros ((n_data , self .count ), dtype = x . dtype )
224
+ v = np .zeros ((n_data , self .count ), dtype = self . coefficient_dtype )
216
225
217
226
for isample in range (0 , n_data ):
218
227
b = self .evaluate_t (self ._cls (x [isample ])).T
0 commit comments