diff --git a/gallery/tutorials/image_expansion.py b/gallery/tutorials/image_expansion.py index 6cd3a05781..d202973492 100644 --- a/gallery/tutorials/image_expansion.py +++ b/gallery/tutorials/image_expansion.py @@ -143,13 +143,14 @@ logger.info("Finish reconstruction of images from direct PSWF expansion coefficients.") # Calculate the mean value of maximum differences between direct PSWF estimated images and original images -pswf_meanmax = np.mean(np.max(abs(pswf_images - org_images), axis=2)) +diff = (pswf_images - org_images).asnumpy() +pswf_meanmax = np.mean(np.max(abs(diff), axis=2)) logger.info( f"Mean value of maximum differences between PSWF estimated images and original images: {pswf_meanmax}" ) # Calculate the normalized RMSE of the estimated images -pswf_nrmse_ims = anorm(pswf_images - org_images) / anorm(org_images) +pswf_nrmse_ims = anorm(diff) / anorm(org_images) logger.info(f"PSWF Estimated images normalized RMSE: {pswf_nrmse_ims}") # plot the first images using the direct PSWF method @@ -186,13 +187,14 @@ logger.info("Finish reconstruction of images from fast PSWF expansion coefficients.") # Calculate mean value of maximum differences between the fast PSWF estimated images and the original images -fpswf_meanmax = np.mean(np.max(abs(fpswf_images - org_images), axis=0)) +diff = (fpswf_images - org_images).asnumpy() +fpswf_meanmax = np.mean(np.max(abs(diff), axis=0)) logger.info( f"Mean value of maximum differences between FPSWF estimated images and original images: {fpswf_meanmax}" ) # Calculate the normalized RMSE of the estimated images -fpswf_nrmse_ims = anorm(fpswf_images - org_images) / anorm(org_images) +fpswf_nrmse_ims = anorm(diff) / anorm(org_images) logger.info(f"FPSWF Estimated images normalized RMSE: {fpswf_nrmse_ims}") # plot the first images using the fast PSWF method @@ -264,3 +266,5 @@ plt.imshow(np.real(org_images[0] - fpswf_images[0]), cmap="gray") plt.title("Differences") plt.tight_layout() + +plt.show() diff --git a/src/aspire/basis/fpswf_2d.py b/src/aspire/basis/fpswf_2d.py index 8a82886df6..240ed160c4 100644 --- a/src/aspire/basis/fpswf_2d.py +++ b/src/aspire/basis/fpswf_2d.py @@ -1,13 +1,13 @@ import logging import numpy as np -from numpy import pi from numpy.linalg import lstsq from scipy.optimize import least_squares from scipy.special import jn from aspire.basis.basis_utils import lgwt, t_x_mat, t_x_mat_dot from aspire.basis.pswf_2d import PSWFBasis2D +from aspire.image import Image from aspire.nufft import nufft from aspire.numeric import fft, xp from aspire.utils import complex_type @@ -86,9 +86,8 @@ def _precomp(self): self.quad_rule_radial_wts = e self.num_angular_pts = f - # pre computing variables for forward - us_fft_pts = np.column_stack((self.quad_rule_pts_y, self.quad_rule_pts_x)) - us_fft_pts = self.bandlimit / (self.rcut * np.pi * 2) * us_fft_pts # for pynfft + us_fft_pts = np.row_stack((self.quad_rule_pts_x, self.quad_rule_pts_y)) + us_fft_pts = self.bandlimit / self.rcut * us_fft_pts ( blk_r, num_angular_pts, @@ -109,67 +108,30 @@ def _precomp(self): def evaluate_t(self, images): """ - Evaluate coefficient vectors in PSWF basis using the fast method + Evaluate coefficient vectors in PSWF basis using the fast method. - :param images: coefficient array in the standard 2D coordinate basis - to be evaluated. - :return : The evaluation of the coefficient array in the PSWF basis. + :param images: Image stack in the standard 2D coordinate basis. + :return : Coefficient array in the PSWF basis. """ - images = np.moveaxis(images, 0, -1) # RCOPT - - # start and finish are for the threads option in the future - images_shape = images.shape - start = 0 - - if len(images_shape) == 3: - # if we got several images - finish = images_shape[2] - else: - # else we got only one image - images_shape = images_shape + (1,) - images = images[..., np.newaxis] - finish = 1 - images_disk = np.zeros(images.shape, dtype=images.dtype, order="F") - images_disk[self._disk_mask, :] = images[self._disk_mask, :] - nfft_res = self._compute_nfft_potts(images_disk, start, finish) - coefficients = self._pswf_integration(nfft_res) - - return coefficients.T # RCOPT - - def evaluate(self, coefficients): - """ - Evaluate coefficients in standard 2D coordinate basis from those in PSWF basis - - :param coefficients: A coefficient vector (or an array of coefficient vectors) - in PSWF basis to be evaluated. - :return : The evaluation of the coefficient vector(s) in standard 2D - coordinate basis. - """ + if not isinstance(images, Image): + logger.warning( + "FPSWFBasis2D.evaluate_t expects Image instance," + " attempting conversion." + ) + images = Image(images) - coefficients = coefficients.T # RCOPT + # Construct array with zeros outside mask + images_disk = np.zeros(images.shape, dtype=images.dtype) + images_disk[:, self._disk_mask] = images[:, self._disk_mask] - # if we got only one vector - if len(coefficients.shape) == 1: - coefficients = coefficients.reshape((len(coefficients), 1)) + # Invoke nufft with the `many` plan (reuse plan/points) + nfft_res = nufft(images_disk, self.us_fft_pts) - angular_is_zero = np.absolute(self.ang_freqs) == 0 - flatten_images = self.samples[:, angular_is_zero].dot( - coefficients[angular_is_zero] - ) + ( - 2.0 - * np.real( - self.samples[:, ~angular_is_zero].dot(coefficients[~angular_is_zero]) - ) - ) + # Accumulate coefficients + coefficients = self._pswf_integration(nfft_res) - n_images = int(flatten_images.shape[1]) - images = np.zeros((self._image_height, self._image_height, n_images)).astype( - complex_type(self.dtype) - ) - images[self._disk_mask, :] = flatten_images - # TODO: no need to switch x and y any more, need to make consistent with direct method - return np.real(images).T # RCOPT + return coefficients def _generate_pswf_quad( self, n, bandlimit, phi_approximate_error, lambda_max, epsilon @@ -360,33 +322,19 @@ def _pswf_integration_sub_routine(self): for i in range(n_max): blk_r[i] = ( temp_const - * self.pswf_radial_quad[ - :, indices_for_n[i] + np.arange(numel_for_n[i]) - ].T + * self.pswf_radial_quad[indices_for_n[i] + np.arange(numel_for_n[i]), :] ) return blk_r, num_angular_pts, r_quad_indices, numel_for_n, indices_for_n, n_max - def _compute_nfft_potts(self, images, start, finish): - """ - Perform NuFFT transform for images in rectangular coordinates - """ - x = self.us_fft_pts - num_images = finish - start - - m = x.shape[0] - - images_nufft = np.zeros((m, num_images), dtype=complex_type(self.dtype)) - for i in range(start, finish): - images_nufft[:, i - start] = nufft(images[..., i], 2 * pi * x.T) - - return images_nufft - def _pswf_integration(self, images_nufft): """ Perform integration part for rotational invariant property. """ - num_images = images_nufft.shape[1] + # Handle both singleton and stacks + images_nufft = np.atleast_2d(images_nufft) + num_images = images_nufft.shape[0] + n_max_float = float(self.n_max) / 2 r_n_eval_mat = np.zeros( (len(self.radial_quad_pts), self.n_max, num_images), @@ -395,10 +343,10 @@ def _pswf_integration(self, images_nufft): for i in range(len(self.radial_quad_pts)): curr_r_mat = images_nufft[ + :, self.r_quad_indices[i] : self.r_quad_indices[i] + self.num_angular_pts[i], - :, - ] + ].T curr_r_mat = np.concatenate((curr_r_mat, np.conj(curr_r_mat))) fft_plan = xp.asnumpy(fft.fft(xp.asarray(curr_r_mat), axis=0)) angular_eval = fft_plan * self.quad_rule_radial_wts[i] @@ -412,12 +360,12 @@ def _pswf_integration(self, images_nufft): (len(self.radial_quad_pts) * self.n_max, num_images), order="F" ) coeff_vec_quad = np.zeros( - (len(self.ang_freqs), num_images), dtype=complex_type(self.dtype) + (num_images, len(self.ang_freqs)), dtype=complex_type(self.dtype) ) - m = len(self.pswf_radial_quad) + m = self.pswf_radial_quad.shape[1] for i in range(self.n_max): coeff_vec_quad[ - self.indices_for_n[i] + np.arange(self.numel_for_n[i]), : - ] = np.dot(self.blk_r[i], r_n_eval_mat[i * m : (i + 1) * m, :]) + :, self.indices_for_n[i] + np.arange(self.numel_for_n[i]) + ] = np.dot(self.blk_r[i], r_n_eval_mat[i * m : (i + 1) * m, :]).T return coeff_vec_quad diff --git a/src/aspire/basis/pswf_2d.py b/src/aspire/basis/pswf_2d.py index 5acc3470f5..67ae580157 100644 --- a/src/aspire/basis/pswf_2d.py +++ b/src/aspire/basis/pswf_2d.py @@ -12,6 +12,7 @@ t_x_mat, ) from aspire.basis.pswf_utils import BNMatrix +from aspire.image import Image from aspire.utils import complex_type logger = logging.getLogger(__name__) @@ -93,7 +94,6 @@ def _generate_grid(self): self._theta_disk = np.angle(x + 1j * y) self._image_height = len(x_1d_grid) self._disk_mask = points_in_disk - self._disk_mask_vec = points_in_disk.reshape(self._image_height ** 2) def _precomp(self): """ @@ -134,7 +134,7 @@ def _generate_samples(self): alpha_all.extend(alpha[:n_end]) m += 1 - self.alpha_nn = np.array(alpha_all) + self.alpha_nn = np.array(alpha_all).reshape(-1, 1) self.max_ns = max_ns self.samples = self._evaluate_pswf2d_all(self._r_disk, self._theta_disk, max_ns) @@ -153,48 +153,46 @@ def evaluate_t(self, images): to be evaluated. :return : The evaluation of the coefficient array in the PSWF basis. """ - images = images.T # RCOPT - images_shape = images.shape + if not isinstance(images, Image): + logger.warning( + "FPSWFBasis2D.evaluate_t expects Image instance," + " attempting conversion." + ) + images = Image(images) - images_shape = (images_shape + (1,)) if len(images_shape) == 2 else images_shape - flattened_images = images.reshape( - (images_shape[0] * images_shape[1], images_shape[2]), order="F" - ) + flattened_images = images[:, self._disk_mask] - flattened_images = flattened_images[self._disk_mask_vec, :] - coefficients = self.samples_conj_transpose.dot(flattened_images) - return coefficients.T + return flattened_images @ self.samples_conj_transpose def evaluate(self, coefficients): """ Evaluate coefficients in standard 2D coordinate basis from those in PSWF basis :param coeffcients: A coefficient vector (or an array of coefficient - vectors) in PSWF basis to be evaluated. - :return : The evaluation of the coefficient vector(s) in standard 2D - coordinate basis. + vectors) in PSWF basis to be evaluated. (n_image, count) + :return : Image in standard 2D coordinate basis. + """ - coefficients = coefficients.T # RCOPT - # if we got only one vector - if len(coefficients.shape) == 1: - coefficients = coefficients[:, np.newaxis] + # Handle a single coefficient vector or stack of vectors. + coefficients = np.atleast_2d(coefficients) + n_images = coefficients.shape[0] angular_is_zero = np.absolute(self.ang_freqs) == 0 - flatten_images = self.samples[:, angular_is_zero].dot( - coefficients[angular_is_zero] - ) + 2.0 * np.real( - self.samples[:, ~angular_is_zero].dot(coefficients[~angular_is_zero]) + + flatten_images = coefficients[:, angular_is_zero] @ self.samples[ + angular_is_zero + ] + 2.0 * np.real( + coefficients[:, ~angular_is_zero] @ self.samples[~angular_is_zero] ) - n_images = int(flatten_images.shape[1]) - images = np.zeros((self._image_height, self._image_height, n_images)).astype( - complex_type(self.dtype) + images = np.zeros( + (n_images, self._image_height, self._image_height), dtype=self.dtype ) - images[self._disk_mask, :] = flatten_images - images = np.transpose(images, axes=(1, 0, 2)) - return np.real(images).T # RCOPT + images[:, self._disk_mask] = np.real(flatten_images) + + return Image(images) def _init_pswf_func2d(self, c, eps): """ @@ -249,10 +247,10 @@ def _evaluate_pswf2d_all(self, r, theta, max_ns): :param theta: Phase part to evaluate :param max_ns: List of ints max_ns[i] is max n to to use for N=i, not included. If max_ns[i]<1 N=i won't be used - :return: (len(r), sum(max_ns)) ndarray + :return: (sum(max_ns), len(r)) ndarray Indices are corresponding to the list (N, n) - (0, 0),..., (0, max_ns[0]), (1, 0),..., (1, max_ns[1]),... , (len(max_ns)-1, 0), - (len(max_ns)-1, max_ns[-1]) + (0, 0),..., (max_ns[0], 0), (0, 1),..., (max_ns[1], 1),... , (0, len(max_ns)-1), + (max_ns[-1], len(max_ns)-1) """ max_ns_ints = [int(max_n) for max_n in max_ns] out_mat = [] @@ -271,7 +269,7 @@ def _evaluate_pswf2d_all(self, r, theta, max_ns): pswf_n_n_mat = phase_part * r_radial_part_mat.T out_mat.extend(pswf_n_n_mat) - out_mat = np.array(out_mat, dtype=complex_type(self.dtype)).T + out_mat = np.array(out_mat, dtype=complex_type(self.dtype)) return out_mat def pswf_func2d(self, big_n, n, bandlimit, phi_approximate_error, r, w): diff --git a/src/aspire/reconstruction/estimator.py b/src/aspire/reconstruction/estimator.py index 9de91b9193..d4a08958f0 100644 --- a/src/aspire/reconstruction/estimator.py +++ b/src/aspire/reconstruction/estimator.py @@ -15,6 +15,15 @@ class Estimator: def __init__(self, src, basis, batch_size=512, preconditioner="circulant"): + """ + An object representing a 2*L-by-2*L-by-2*L array containing the non-centered Fourier transform of the mean + least-squares estimator kernel. + Convolving a volume with this kernel is equal to projecting and backproject-ing that volume in each of the + projection directions (with the appropriate amplitude multipliers and CTFs) and averaging over the whole + dataset. + Note that this is a non-centered Fourier transform, so the zero frequency is found at index 1. + """ + self.src = src self.basis = basis self.dtype = self.src.dtype @@ -36,15 +45,6 @@ def __init__(self, src, basis, batch_size=512, preconditioner="circulant"): f" Given src.L={src.L} != {basis.nres}" ) - """ - An object representing a 2*L-by-2*L-by-2*L array containing the non-centered Fourier transform of the mean - least-squares estimator kernel. - Convolving a volume with this kernel is equal to projecting and backproject-ing that volume in each of the - projection directions (with the appropriate amplitude multipliers and CTFs) and averaging over the whole - dataset. - Note that this is a non-centered Fourier transform, so the zero frequency is found at index 1. - """ - def __getattr__(self, name): """Lazy attributes instantiated on first-access""" diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index c4b7d4d077..ac88034cce 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -48,11 +48,10 @@ class ImageSource: objects, depending on unique CTF values found for _rlnDefocusU/_rlnDefocusV etc. """ - """ - The metadata_fields dictionary below specifies default data types of certain key fields used in the codebase. - The STAR file used to initialize subclasses of ImageSource may well contain other columns not found below; these - additional columns are available when read, and they default to the pandas data type 'object'. - """ + # The metadata_fields dictionary below specifies default data types of certain key fields used in the codebase. + # The STAR file used to initialize subclasses of ImageSource may well contain other columns not found below; these + # additional columns are available when read, and they default to the pandas data type 'object'. + metadata_fields = { "_rlnVoltage": float, "_rlnDefocusU": float, diff --git a/tests/saved_test_data/fpswf2d_vcoeffs_out_8_8.npy b/tests/saved_test_data/fpswf2d_vcoeffs_out_8_8.npy deleted file mode 100644 index 10425b3cc5..0000000000 Binary files a/tests/saved_test_data/fpswf2d_vcoeffs_out_8_8.npy and /dev/null differ diff --git a/tests/saved_test_data/fpswf2d_xcoeffs_out_8_8.npy b/tests/saved_test_data/fpswf2d_xcoeffs_out_8_8.npy deleted file mode 100644 index c429cb8b74..0000000000 Binary files a/tests/saved_test_data/fpswf2d_xcoeffs_out_8_8.npy and /dev/null differ diff --git a/tests/test_FPSWFbasis2D.py b/tests/test_FPSWFbasis2D.py index 7950f308d9..8bc0ab215a 100644 --- a/tests/test_FPSWFbasis2D.py +++ b/tests/test_FPSWFbasis2D.py @@ -4,6 +4,7 @@ import numpy as np from aspire.basis import FPSWFBasis2D +from aspire.image import Image DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -16,12 +17,21 @@ def tearDown(self): pass def testFPSWFBasis2DEvaluate_t(self): - # RCOPT, this image reference is a single image 8,8. Transpose no needed. - images = np.load(os.path.join(DATA_DIR, "ffbbasis2d_xcoeff_in_8_8.npy")) + img_ary = np.load( + os.path.join(DATA_DIR, "ffbbasis2d_xcoeff_in_8_8.npy") + ).T # RCOPT + images = Image(img_ary) + result = self.basis.evaluate_t(images) + result_ary = self.basis.evaluate_t(img_ary) + + # Confirm output from passing ndarray or Image is the same + self.assertTrue(np.allclose(result, result_ary)) + coeffs = np.load( - os.path.join(DATA_DIR, "fpswf2d_vcoeffs_out_8_8.npy") + os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") ).T # RCOPT + # make sure both real and imaginary parts are consistent. self.assertTrue( np.allclose(np.real(result), np.real(coeffs)) @@ -30,10 +40,8 @@ def testFPSWFBasis2DEvaluate_t(self): def testFPSWFBasis2DEvaluate(self): coeffs = np.load( - os.path.join(DATA_DIR, "fpswf2d_vcoeffs_out_8_8.npy") + os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") ).T # RCOPT result = self.basis.evaluate(coeffs) - images = np.load( - os.path.join(DATA_DIR, "fpswf2d_xcoeffs_out_8_8.npy") - ).T # RCOPT - self.assertTrue(np.allclose(result, images)) + images = np.load(os.path.join(DATA_DIR, "pswf2d_xcoeff_out_8_8.npy")).T # RCOPT + self.assertTrue(np.allclose(result.asnumpy(), images)) diff --git a/tests/test_PSWFbasis2D.py b/tests/test_PSWFbasis2D.py index 0f7e05f78d..2cf34feacf 100644 --- a/tests/test_PSWFbasis2D.py +++ b/tests/test_PSWFbasis2D.py @@ -4,6 +4,7 @@ import numpy as np from aspire.basis import PSWFBasis2D +from aspire.image import Image DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -16,13 +17,21 @@ def tearDown(self): pass def testPSWFBasis2DEvaluate_t(self): - images = np.load( + img_ary = np.load( os.path.join(DATA_DIR, "ffbbasis2d_xcoeff_in_8_8.npy") ).T # RCOPT + images = Image(img_ary) + result = self.basis.evaluate_t(images) + result_ary = self.basis.evaluate_t(img_ary) + + # Confirm output from passing ndarray or Image is the same + self.assertTrue(np.allclose(result, result_ary)) + coeffs = np.load( os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") ).T # RCOPT + # make sure both real and imaginary parts are consistent. self.assertTrue( np.allclose(np.real(result), np.real(coeffs)) @@ -35,4 +44,4 @@ def testPSWFBasis2DEvaluate(self): ).T # RCOPT result = self.basis.evaluate(coeffs) images = np.load(os.path.join(DATA_DIR, "pswf2d_xcoeff_out_8_8.npy")).T # RCOPT - self.assertTrue(np.allclose(result, images)) + self.assertTrue(np.allclose(result.asnumpy(), images))