Skip to content

Commit f53a989

Browse files
committed
Continue increasing test coverage.
1 parent 9191a5b commit f53a989

File tree

5 files changed

+165
-56
lines changed

5 files changed

+165
-56
lines changed

swiftgalaxy/reader.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,7 @@ def _mask_dataset(self, mask: slice) -> None:
848848
else:
849849
if self._swiftgalaxy._spatial_mask is None:
850850
# get a count of particles in the box
851-
num_part = self._particle_dataset.metadata.num_part[
852-
particle_metadata.particle_type
853-
]
851+
num_part = getattr(self.metadata, f"n_{particle_metadata.group_name}")
854852
else:
855853
# get a count of particles in the spatial mask region
856854
num_part = np.sum(
@@ -1029,20 +1027,16 @@ def spherical_coordinates(self) -> _CoordinateHelper:
10291027
a**0, scale_factor=r.cosmo_factor.scale_factor
10301028
),
10311029
)
1032-
if self.cylindrical_coordinates is not None:
1030+
if self._cylindrical_coordinates is not None:
10331031
phi = self.cylindrical_coordinates.phi
10341032
else:
1035-
phi = cosmo_array(
1033+
phi = (
10361034
np.arctan2(
10371035
self.cartesian_coordinates.y, self.cartesian_coordinates.x
1038-
),
1039-
units=unyt.rad,
1040-
comoving=r.comoving,
1041-
cosmo_factor=cosmo_factor(
1042-
a**0, scale_factor=r.cosmo_factor.scale_factor
1043-
),
1044-
)
1045-
phi[phi < 0] = phi[phi < 0] + 2 * np.pi * unyt.rad
1036+
)
1037+
* unyt.rad
1038+
) # arctan2 returns dimensionless
1039+
phi[phi < 0] += 2 * np.pi * np.ones_like(phi)[phi < 0]
10461040
self._spherical_coordinates = dict(_r=r, _theta=theta, _phi=phi)
10471041
return _CoordinateHelper(
10481042
self._spherical_coordinates,
@@ -1198,19 +1192,13 @@ def cylindrical_coordinates(self) -> _CoordinateHelper:
11981192
if self._spherical_coordinates is not None:
11991193
phi = self.spherical_coordinates.phi
12001194
else:
1201-
# np.where returns ndarray
1202-
phi = np.arctan2(
1203-
self.cartesian_coordinates.y, self.cartesian_coordinates.x
1204-
).view(np.ndarray)
1205-
phi = np.where(phi < 0, phi + 2 * np.pi, phi)
1206-
phi = cosmo_array(
1207-
phi,
1208-
units=unyt.rad,
1209-
comoving=rho.comoving,
1210-
cosmo_factor=cosmo_factor(
1211-
a**0, scale_factor=rho.cosmo_factor.scale_factor
1212-
),
1213-
)
1195+
phi = (
1196+
np.arctan2(
1197+
self.cartesian_coordinates.y, self.cartesian_coordinates.x
1198+
)
1199+
* unyt.rad
1200+
) # arctan2 returns dimensionless
1201+
phi[phi < 0] += 2 * np.pi * np.ones_like(phi)[phi < 0]
12141202
z = self.cartesian_coordinates.z
12151203
self._cylindrical_coordinates = dict(_rho=rho, _phi=phi, _z=z)
12161204
return _CoordinateHelper(

tests/test_coordinate_transformations.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from unyt.testing import assert_allclose_units
55
from swiftsimio.objects import cosmo_array, cosmo_factor, a, cosmo_quantity
66
from scipy.spatial.transform import Rotation
7-
from toysnap import present_particle_types, toysnap_filename, ToyHF
7+
from toysnap import (
8+
present_particle_types,
9+
toysnap_filename,
10+
ToyHF,
11+
create_toysnap,
12+
remove_toysnap,
13+
)
814
from swiftgalaxy import SWIFTGalaxy
915
from swiftgalaxy.reader import _apply_translation, _apply_4transform
1016

@@ -604,6 +610,27 @@ def test_auto_recentering_with_copied_coordinate_frame(self, sg):
604610
toysnap_filename, ToyHF(), auto_recentre=True, coordinate_frame_from=sg
605611
)
606612

613+
def test_invalid_coordinate_frame_from(self, sg):
614+
"""
615+
Check that we get an error if coordinate_frame_from has mismatched internal units.
616+
"""
617+
new_time_unit = u.s
618+
assert sg.metadata.units.time != new_time_unit
619+
sg.metadata.units.time = new_time_unit
620+
try:
621+
create_toysnap()
622+
with pytest.raises(
623+
ValueError, match="Internal units \\(length and time\\) of"
624+
):
625+
SWIFTGalaxy(
626+
toysnap_filename,
627+
ToyHF(),
628+
coordinate_frame_from=sg,
629+
auto_recentre=False,
630+
)
631+
finally:
632+
remove_toysnap()
633+
607634
def test_copied_coordinate_transform(self, sg):
608635
"""
609636
Check that a SWIFTGalaxy initialised to copy the coordinate frame

tests/test_copy.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import numpy as np
12
from copy import copy, deepcopy
23
import pytest
34
import unyt as u
45
from unyt.testing import assert_allclose_units
6+
from swiftgalaxy.masks import MaskCollection
57

68
abstol_m = 1e2 * u.solMass
79
reltol_m = 1.0e-4
@@ -139,3 +141,44 @@ def test_deepcopy_namedcolumn(self, sg):
139141
rtol=reltol_nd,
140142
atol=abstol_nd,
141143
)
144+
145+
146+
class TestCopyMaskCollection:
147+
def test_copy_mask_collection(self):
148+
"""
149+
Test that masks get copied.
150+
"""
151+
mc = MaskCollection(
152+
gas=np.ones(100, dtype=bool),
153+
dark_matter=np.s_[:20],
154+
stars=None,
155+
black_holes=np.arange(3),
156+
)
157+
mc_copy = copy(mc)
158+
assert set(mc_copy.__dict__.keys()) == set(mc.__dict__.keys())
159+
for k in ("gas", "dark_matter", "stars", "black_holes"):
160+
comparison = getattr(mc, k) == getattr(mc_copy, k)
161+
if type(comparison) is bool:
162+
assert comparison
163+
else:
164+
assert all(comparison)
165+
166+
def test_deepcopy_mask_collection(self):
167+
"""
168+
Test that masks get copied along with values. Since the object isn't
169+
really "deep", shallow copy and deepcopy have the same expectation.
170+
"""
171+
mc = MaskCollection(
172+
gas=np.ones(100, dtype=bool),
173+
dark_matter=np.s_[:20],
174+
stars=None,
175+
black_holes=np.arange(3),
176+
)
177+
mc_copy = deepcopy(mc)
178+
assert set(mc_copy.__dict__.keys()) == set(mc.__dict__.keys())
179+
for k in ("gas", "dark_matter", "stars", "black_holes"):
180+
comparison = getattr(mc, k) == getattr(mc_copy, k)
181+
if type(comparison) is bool:
182+
assert comparison
183+
else:
184+
assert all(comparison)

tests/test_creation.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1+
"""
2+
Tests checking that we can create objects, if these fail something
3+
fundamental has gone wrong.
4+
"""
5+
16
from swiftgalaxy import SWIFTGalaxy
2-
from toysnap import (
3-
create_toysnap,
4-
remove_toysnap,
5-
toysnap_filename,
6-
n_g_all,
7-
n_dm_all,
8-
n_s_all,
9-
n_bh_all,
10-
)
7+
from toysnap import create_toysnap, remove_toysnap, toysnap_filename, n_g_1
118

129

1310
class TestSWIFTGalaxyCreation:
@@ -81,30 +78,36 @@ def test_tab_completion(self, sg):
8178
# finally, check that we didn't lazy-load everything!
8279
assert sg.gas._particle_dataset._coordinates is None
8380

84-
def test_no_masks(self):
81+
def test_mask_preloaded_namedcolumn(self):
8582
"""
86-
Check that if we have no masks we read everything in the box (and warn about it).
83+
If namedcolumn data was loaded during evaluation of a mask, it needs to be masked
84+
during initialization.
8785
"""
86+
from toysnap import ToyHF
87+
88+
def load_namedcolumn(method):
89+
90+
def wrapper(self, sg):
91+
sg.gas.hydrogen_ionization_fractions.neutral
92+
return method(self, sg)
93+
94+
return wrapper
95+
96+
# decorate the mask evaluation to load an (unused) namedcolumn
97+
ToyHF._generate_bound_only_mask = load_namedcolumn(
98+
ToyHF._generate_bound_only_mask
99+
)
100+
88101
try:
89102
create_toysnap()
90-
sg = SWIFTGalaxy(
91-
toysnap_filename,
92-
None, # no halo_catalogue is easiest way to get no mask
93-
transforms_like_coordinates={"coordinates", "extra_coordinates"},
94-
transforms_like_velocities={"velocities", "extra_velocities"},
103+
sg = SWIFTGalaxy(toysnap_filename, ToyHF())
104+
# confirm that we loaded a namedcolumn during initialization:
105+
assert (
106+
sg.gas.hydrogen_ionization_fractions._internal_dataset._neutral
107+
is not None
95108
)
96-
# check that extra mask is blank for all particle types:
97-
assert sg._extra_mask.gas is None
98-
assert sg._extra_mask.dark_matter is None
99-
assert sg._extra_mask.stars is None
100-
assert sg._extra_mask.black_holes is None
101-
# check that cell mask is blank for all particle types:
102-
assert sg._spatial_mask is None
103-
# check that we read all the particles:
104-
assert sg.gas.masses.size == n_g_all
105-
assert sg.dark_matter.masses.size == n_dm_all
106-
assert sg.stars.masses.size == n_s_all
107-
assert sg.black_holes.masses.size == n_bh_all
109+
# confirm that it got masked:
110+
assert sg.gas.hydrogen_ionization_fractions.neutral.size == n_g_1
108111
finally:
109112
remove_toysnap()
110113

tests/test_masking.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
1+
"""
2+
Tests for applying masks to swiftgalaxy, datasets and named columns.
3+
"""
4+
15
import pytest
26
import numpy as np
37
from unyt.testing import assert_allclose_units
48
from toysnap import present_particle_types
5-
from swiftgalaxy import MaskCollection
9+
from swiftgalaxy import SWIFTGalaxy, MaskCollection
10+
from toysnap import (
11+
create_toysnap,
12+
remove_toysnap,
13+
toysnap_filename,
14+
n_g_all,
15+
n_dm_all,
16+
n_s_all,
17+
n_bh_all,
18+
)
619

720
abstol_nd = 1.0e-4
821
reltol_nd = 1.0e-4
@@ -72,6 +85,41 @@ def test_namedcolumn_masked(self, sg, before_load):
7285
neutral_before[mask], neutral, rtol=reltol_nd, atol=abstol_nd
7386
)
7487

88+
def test_mask_without_spatial_mask(self):
89+
"""
90+
Check that if we have no masks we read everything in the box (and warn about it).
91+
Then that we can still apply an extra mask, and a second one (there's specific
92+
logic for applying two consecutively).
93+
"""
94+
try:
95+
create_toysnap()
96+
sg = SWIFTGalaxy(
97+
toysnap_filename,
98+
None, # no halo_catalogue is easiest way to get no mask
99+
transforms_like_coordinates={"coordinates", "extra_coordinates"},
100+
transforms_like_velocities={"velocities", "extra_velocities"},
101+
)
102+
# check that extra mask is blank for all particle types:
103+
assert sg._extra_mask.gas is None
104+
assert sg._extra_mask.dark_matter is None
105+
assert sg._extra_mask.stars is None
106+
assert sg._extra_mask.black_holes is None
107+
# check that cell mask is blank for all particle types:
108+
assert sg._spatial_mask is None
109+
# check that we read all the particles:
110+
assert sg.gas.masses.size == n_g_all
111+
assert sg.dark_matter.masses.size == n_dm_all
112+
assert sg.stars.masses.size == n_s_all
113+
assert sg.black_holes.masses.size == n_bh_all
114+
# now apply an extra mask
115+
sg.mask_particles(MaskCollection(gas=np.s_[:1000]))
116+
assert sg.gas.masses.size == 1000
117+
# and the second consecutive one
118+
sg.mask_particles(MaskCollection(gas=np.s_[:100]))
119+
assert sg.gas.masses.size == 100
120+
finally:
121+
remove_toysnap()
122+
75123

76124
class TestMaskingParticleDatasets:
77125
@pytest.mark.parametrize("particle_name", present_particle_types.values())

0 commit comments

Comments
 (0)