Skip to content

Commit f4b5b56

Browse files
committed
Separate convert_sh_to_sf
1 parent 95ec0f7 commit f4b5b56

File tree

2 files changed

+60
-42
lines changed

scilpy/reconst/fodf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,15 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis,
127127

128128
def _fit_from_model_parallel(args):
129129
(model, data, chunk_id) = args
130-
sub_fit_array = _fit_from_model_2d(data, model)
130+
sub_fit_array = _fit_from_model_loop(data, model)
131131

132132
return chunk_id, sub_fit_array
133133

134134

135-
def _fit_from_model_2d(data, model):
135+
def _fit_from_model_loop(data, model):
136136
"""
137-
Loops on 2D data and fits each voxel separately
137+
Loops on 2D data and fits each voxel separately.
138+
See fit_from_model for more information.
138139
"""
139140
# Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels.
140141
tmp_fit_array = np.zeros((data.shape[0],), dtype='object')
@@ -191,7 +192,7 @@ def fit_from_model(model, data, mask=None, nbr_processes=None):
191192
# Separating the case nbr_processes=1 to help get good coverage metrics
192193
# (codecov does not deal well with multiprocessing)
193194
if nbr_processes == 1:
194-
tmp_fit_array = _fit_from_model_2d(data, model)
195+
tmp_fit_array = _fit_from_model_loop(data, model)
195196
else:
196197
# Separate the data in chunks of len(nbr_processes).
197198
chunks = np.array_split(data, nbr_processes)

scilpy/reconst/sh.py

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,16 @@ def _peaks_from_sh_parallel(args):
185185
absolute_threshold, min_separation_angle,
186186
npeaks, normalize_peaks, chunk_id, is_symmetric) = args
187187

188-
peak_dirs, peak_values, peak_indices = _peaks_from_sh_2d(
188+
peak_dirs, peak_values, peak_indices = _peaks_from_sh_loop(
189189
shm_coeff, B, sphere, relative_peak_threshold,
190190
absolute_threshold, min_separation_angle, npeaks,
191191
normalize_peaks, is_symmetric)
192192
return chunk_id, peak_dirs, peak_values, peak_indices
193193

194194

195-
def _peaks_from_sh_2d(shm_coeff, B, sphere, relative_peak_threshold,
196-
absolute_threshold, min_separation_angle, npeaks,
197-
normalize_peaks, is_symmetric):
195+
def _peaks_from_sh_loop(shm_coeff, B, sphere, relative_peak_threshold,
196+
absolute_threshold, min_separation_angle, npeaks,
197+
normalize_peaks, is_symmetric):
198198
"""
199199
Loops on 2D (ravelled) data and fits each voxel separately.
200200
See peaks_from_sh for a complete description of parameters.
@@ -307,7 +307,7 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5,
307307
# (codecov does not deal well with multiprocessing)
308308
if nbr_processes == 1:
309309
(tmp_peak_dirs_array, tmp_peak_values_array,
310-
tmp_peak_indices_array) = _peaks_from_sh_2d(
310+
tmp_peak_indices_array) = _peaks_from_sh_loop(
311311
shm_coeff, B, sphere, relative_peak_threshold,
312312
absolute_threshold, min_separation_angle, npeaks,
313313
normalize_peaks, is_symmetric)
@@ -358,13 +358,13 @@ def _maps_from_sh_parallel(args):
358358
(shm_coeff, peak_values, peak_indices, B, sphere,
359359
gfa_thr, chunk_id) = args
360360

361-
res = _maps_from_sh_2d(shm_coeff, peak_values, peak_indices, B,
362-
sphere, gfa_thr)
361+
res = _maps_from_sh_loop(shm_coeff, peak_values, peak_indices, B,
362+
sphere, gfa_thr)
363363
return chunk_id, *res
364364

365365

366-
def _maps_from_sh_2d(shm_coeff, peak_values, peak_indices, B, sphere,
367-
gfa_thr):
366+
def _maps_from_sh_loop(shm_coeff, peak_values, peak_indices, B, sphere,
367+
gfa_thr):
368368
"""
369369
Loops on 2D (ravelled) data and fits each voxel separately.
370370
For a more complete description of parameters, see maps_from_sh.
@@ -462,7 +462,7 @@ def maps_from_sh(shm_coeff, peak_values, peak_indices, sphere,
462462
if nbr_processes == 1:
463463
(tmp_nufo_map_array, tmp_afd_max_array, tmp_afd_sum_array,
464464
tmp_rgb_map_array, tmp_gfa_map_array, tmp_qa_map_array,
465-
all_time_max_odf, all_time_global_max) = _maps_from_sh_2d(
465+
all_time_max_odf, all_time_global_max) = _maps_from_sh_loop(
466466
shm_coeff, peak_values, peak_indices,
467467
B, sphere, gfa_thr)
468468
else:
@@ -539,12 +539,12 @@ def maps_from_sh(shm_coeff, peak_values, peak_indices, sphere,
539539

540540
def _convert_sh_basis_parallel(args):
541541
(sh, B_in, invB_out, chunk_id) = args
542-
sh = _convert_sh_basis_2d(sh, B_in, invB_out)
542+
sh = _convert_sh_basis_loop(sh, B_in, invB_out)
543543

544544
return chunk_id, sh
545545

546546

547-
def _convert_sh_basis_2d(sh, B_in, invB_out):
547+
def _convert_sh_basis_loop(sh, B_in, invB_out):
548548
"""
549549
Loops on 2D (ravelled) data and fits each voxel separately.
550550
For a more complete description of parameters, see convert_sh_basis.
@@ -625,7 +625,7 @@ def convert_sh_basis(shm_coeff, sphere, mask=None,
625625
# Separating the case nbr_processes=1 to help get good coverage metrics
626626
# (codecov does not deal well with multiprocessing)
627627
if nbr_processes == 1:
628-
tmp_shm_coeff_array = _convert_sh_basis_2d(shm_coeff, B_in, invB_out)
628+
tmp_shm_coeff_array = _convert_sh_basis_loop(shm_coeff, B_in, invB_out)
629629
else:
630630
# Separate the data in chunks of len(nbr_processes).
631631
shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes)
@@ -653,17 +653,24 @@ def convert_sh_basis(shm_coeff, sphere, mask=None,
653653

654654

655655
def _convert_sh_to_sf_parallel(args):
656-
sh = args[0]
657-
B_in = args[1]
658-
new_output_dim = args[2]
659-
chunk_id = args[3]
656+
(sh, B_in, new_output_dim, chunk_id) = args
657+
sf = _convert_sh_to_sf_loop(sh, new_output_dim, B_in)
658+
return chunk_id, sf
659+
660+
661+
def _convert_sh_to_sf_loop(sh, new_output_dim, B_in):
662+
"""
663+
Loops on 2D data and fits each voxel separately.
664+
See convert_sh_to_sf for more information.
665+
"""
666+
# Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels.
660667
sf = np.zeros((sh.shape[0], new_output_dim), dtype=np.float32)
661668

662669
for idx in range(sh.shape[0]):
663670
if sh[idx].any():
664671
sf[idx] = np.dot(sh[idx], B_in)
665672

666-
return chunk_id, sf
673+
return sf
667674

668675

669676
def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32",
@@ -716,30 +723,40 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32",
716723
if mask is None:
717724
mask = np.sum(shm_coeff, axis=3).astype(bool)
718725

726+
output_dim = len(sphere.vertices)
727+
new_shape = data_shape[:3] + (output_dim,)
728+
719729
# Ravel the first 3 dimensions while keeping the 4th intact, like a list of
720-
# 1D time series voxels. Then separate it in chunks of len(nbr_processes).
730+
# 1D time series voxels.
721731
shm_coeff = shm_coeff[mask].reshape(
722732
(np.count_nonzero(mask), data_shape[3]))
723-
shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes)
724-
chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks])
725-
726-
pool = multiprocessing.Pool(nbr_processes)
727-
results = pool.map(_convert_sh_to_sf_parallel,
728-
zip(shm_coeff_chunks,
729-
itertools.repeat(B_in),
730-
itertools.repeat(len(sphere.vertices)),
731-
np.arange(len(shm_coeff_chunks))))
732-
pool.close()
733-
pool.join()
734-
735-
# Re-assemble the chunk together in the original shape.
736-
new_shape = data_shape[:3] + (len(sphere.vertices),)
737-
sf_array = np.zeros(new_shape, dtype=dtype)
738-
tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[3]),
739-
dtype=dtype)
740-
for i, new_sf in results:
741-
tmp_sf_array[chunk_len[i]:chunk_len[i + 1], :] = new_sf
742733

734+
# Separating the case nbr_processes=1 to help get good coverage metrics
735+
# (codecov does not deal well with multiprocessing)
736+
if nbr_processes == 1:
737+
tmp_sf_array = _convert_sh_to_sf_loop(shm_coeff, output_dim, B_in)
738+
else:
739+
# Separate the data in chunks of len(nbr_processes).
740+
shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes)
741+
742+
pool = multiprocessing.Pool(nbr_processes)
743+
results = pool.map(_convert_sh_to_sf_parallel,
744+
zip(shm_coeff_chunks,
745+
itertools.repeat(B_in),
746+
itertools.repeat(output_dim),
747+
np.arange(len(shm_coeff_chunks))))
748+
pool.close()
749+
pool.join()
750+
751+
# Re-assemble the chunk together.
752+
chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks])
753+
tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[3]),
754+
dtype=dtype)
755+
for i, new_sf in results:
756+
tmp_sf_array[chunk_len[i]:chunk_len[i + 1], :] = new_sf
757+
758+
# Bring back to the original shape
759+
sf_array = np.zeros(new_shape, dtype=dtype)
743760
sf_array[mask] = tmp_sf_array
744761

745762
return sf_array

Comments (0)