Skip to content

Commit 9191a5b

Browse files
committed
Continue increasing test coverage.
1 parent e9c7afc commit 9191a5b

File tree

6 files changed

+188
-34
lines changed

6 files changed

+188
-34
lines changed

swiftgalaxy/reader.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _apply_translation(coords: cosmo_array, offset: cosmo_array) -> cosmo_array:
103103
offset = offset.to_comoving()
104104
elif hasattr(offset, "comoving") and not coords.comoving:
105105
offset = offset.to_physical()
106-
elif not hasattr(offset, "comoving"):
106+
else: # not hasattr(offset, "comoving")
107107
msg = (
108108
"Translation assumed to be in comoving (not physical) coordinates."
109109
if coords.comoving
@@ -1125,10 +1125,13 @@ def spherical_velocities(self) -> _CoordinateHelper:
11251125
+ _sin_t * _sin_p * self.cartesian_velocities.y
11261126
- _cos_t * self.cartesian_velocities.z
11271127
)
1128-
v_p = (
1129-
-_sin_p * self.cartesian_velocities.x
1130-
+ _cos_p * self.cartesian_velocities.y
1131-
)
1128+
if self._cylindrical_velocities is not None:
1129+
v_p = self.cylindrical_velocities.phi
1130+
else:
1131+
v_p = (
1132+
-_sin_p * self.cartesian_velocities.x
1133+
+ _cos_p * self.cartesian_velocities.y
1134+
)
11321135
self._spherical_velocities = dict(_v_r=v_r, _v_t=v_t, _v_p=v_p)
11331136
return _CoordinateHelper(
11341137
self._spherical_velocities,
@@ -1531,7 +1534,7 @@ def __init__(
15311534
self._velocity_like_transform = np.eye(4)
15321535
if self.halo_catalogue is None:
15331536
# in server mode we don't have a halo_catalogue yet
1534-
pass
1537+
self._spatial_mask = getattr(self, "_spatial_mask", None)
15351538
elif self.halo_catalogue._user_spatial_offsets is not None:
15361539
self._spatial_mask = self.halo_catalogue._get_user_spatial_mask(
15371540
self.snapshot_filename
@@ -2022,13 +2025,18 @@ def _transform(self, transform4: cosmo_array, boost: bool = False) -> None:
20222025
if boost
20232026
else self.transforms_like_coordinates
20242027
)
2028+
transform_units = (
2029+
self.metadata.units.length / self.metadata.units.time
2030+
if boost
2031+
else self.metadata.units.length
2032+
)
20252033
for particle_name in self.metadata.present_group_names:
20262034
dataset = getattr(self, particle_name)._particle_dataset
20272035
for field_name in transformable:
20282036
field_data = getattr(dataset, f"_{field_name}")
20292037
if field_data is not None:
20302038
field_data = _apply_4transform(
2031-
field_data, transform4.to_value(), transform4.units
2039+
field_data, transform4, transform_units
20322040
)
20332041
setattr(dataset, f"_{field_name}", field_data)
20342042
if boost:

tests/test_coordinate_transformations.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,18 @@ def test_translate(self, sg, particle_name, coordinate_name, before_load):
363363
xyz = getattr(getattr(sg, particle_name), f"{coordinate_name}")
364364
assert_allclose_units(xyz_before + translation, xyz, rtol=1.0e-4, atol=abstol_c)
365365

366+
def test_translate_warn_comoving_missing(self, sg):
367+
"""
368+
If the translation does not have comoving information issue a warning.
369+
"""
370+
translation = u.unyt_array(
371+
[1, 1, 1],
372+
units=u.Mpc,
373+
)
374+
msg = "Translation assumed to be in comoving"
375+
with pytest.warns(RuntimeWarning, match=msg):
376+
sg.translate(translation)
377+
366378
@pytest.mark.parametrize("velocity_name", ("velocities", "extra_velocities"))
367379
@pytest.mark.parametrize("particle_name", present_particle_types.values())
368380
@pytest.mark.parametrize("before_load", (True, False))
@@ -425,6 +437,64 @@ def test_box_wrap(self, sg, particle_name, coordinate_name):
425437
xyz = getattr(getattr(sg, particle_name), f"{coordinate_name}")
426438
assert_allclose_units(xyz_before, xyz, rtol=1.0e-4, atol=abstol_c)
427439

440+
@pytest.mark.parametrize("coordinate_name", ("coordinates", "extra_coordinates"))
441+
@pytest.mark.parametrize("particle_name", present_particle_types.values())
442+
@pytest.mark.parametrize("before_load", (True, False))
443+
def test_transform(self, sg, particle_name, coordinate_name, before_load):
444+
"""
445+
Check that affine transformation works.
446+
"""
447+
xyz_before = getattr(getattr(sg, particle_name), f"{coordinate_name}")
448+
if before_load:
449+
setattr(
450+
getattr(sg, particle_name)._particle_dataset,
451+
f"_{coordinate_name}",
452+
None,
453+
)
454+
translation = cosmo_array(
455+
[1, 1, 1],
456+
units=u.Mpc,
457+
comoving=True,
458+
cosmo_factor=cosmo_factor(a**1, scale_factor=1.0),
459+
)
460+
transform = np.eye(4)
461+
transform[:3, :3] = rot
462+
transform[3, :3] = translation.to_comoving_value(u.Mpc)
463+
sg._transform(transform)
464+
xyz = getattr(getattr(sg, particle_name), f"{coordinate_name}")
465+
assert_allclose_units(
466+
xyz_before.dot(rot) + translation, xyz, rtol=1.0e-4, atol=abstol_c
467+
)
468+
469+
@pytest.mark.parametrize("coordinate_name", ("velocities", "extra_velocities"))
470+
@pytest.mark.parametrize("particle_name", present_particle_types.values())
471+
@pytest.mark.parametrize("before_load", (True, False))
472+
def test_transform_velocity(self, sg, particle_name, coordinate_name, before_load):
473+
"""
474+
Check that affine transformation works.
475+
"""
476+
xyz_before = getattr(getattr(sg, particle_name), f"{coordinate_name}")
477+
if before_load:
478+
setattr(
479+
getattr(sg, particle_name)._particle_dataset,
480+
f"_{coordinate_name}",
481+
None,
482+
)
483+
translation = cosmo_array(
484+
[100, 100, 100],
485+
units=u.km / u.s,
486+
comoving=True,
487+
cosmo_factor=cosmo_factor(a**0, scale_factor=1.0),
488+
)
489+
transform = np.eye(4)
490+
transform[:3, :3] = rot
491+
transform[3, :3] = translation.to_comoving_value(u.km / u.s)
492+
sg._transform(transform, boost=True)
493+
xyz = getattr(getattr(sg, particle_name), f"{coordinate_name}")
494+
assert_allclose_units(
495+
xyz_before.dot(rot) + translation, xyz, rtol=1.0e-4, atol=abstol_v
496+
)
497+
428498

429499
class TestSequentialTransformations:
430500
@pytest.mark.parametrize("before_load", (True, False))
@@ -640,7 +710,48 @@ def test_comoving_physical_conversion(self, comoving):
640710
comoving=comoving,
641711
cosmo_factor=cosmo_factor(a**1, scale_factor=1.0),
642712
)
643-
# identity 4transform:
644-
transform = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]])
713+
transform = np.eye(4) # identity 4transform
645714
result = _apply_4transform(coords, transform, transform_units=u.Mpc)
646715
assert result.comoving == comoving
716+
717+
718+
class TestCoordinateProperties:
719+
720+
def test_centre(self, sg):
721+
"""
722+
Check the centre attribute.
723+
"""
724+
new_centre = cosmo_array(
725+
[1, 2, 3],
726+
units=u.Mpc,
727+
comoving=True,
728+
cosmo_factor=cosmo_factor(a**1, scale_factor=1.0),
729+
)
730+
sg.recentre(new_centre)
731+
assert_allclose_units(
732+
sg.halo_catalogue.centre + new_centre,
733+
sg.centre,
734+
)
735+
736+
def test_velocity_centre(self, sg):
737+
"""
738+
Check the velocity_centre attribute.
739+
"""
740+
new_centre = cosmo_array(
741+
[100, 200, 300],
742+
units=u.km / u.s,
743+
comoving=True,
744+
cosmo_factor=cosmo_factor(a**0, scale_factor=1.0),
745+
)
746+
sg.recentre_velocity(new_centre)
747+
assert_allclose_units(
748+
sg.halo_catalogue.velocity_centre + new_centre,
749+
sg.velocity_centre,
750+
)
751+
752+
def test_rotation(self, sg):
753+
"""
754+
Check the rotation attribute.
755+
"""
756+
sg.rotate(Rotation.from_matrix(rot))
757+
assert np.allclose(sg.rotation.as_matrix(), rot)

tests/test_copy.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import copy, deepcopy
2+
import pytest
23
import unyt as u
34
from unyt.testing import assert_allclose_units
45

@@ -24,13 +25,19 @@ def test_copy_sg(self, sg):
2425
is None
2526
)
2627

27-
def test_deepcopy_sg(self, sg):
28+
@pytest.mark.parametrize("derived_coords_initialized", [True, False])
29+
def test_deepcopy_sg(self, sg, derived_coords_initialized):
2830
"""
2931
Test that dataset arrays get copied on deep copy.
3032
"""
3133
# lazy load a dataset and a named column
3234
sg.gas.masses
3335
sg.gas.hydrogen_ionization_fractions.neutral
36+
if derived_coords_initialized:
37+
sg.gas.spherical_coordinates
38+
sg.gas.spherical_velocities
39+
sg.gas.cylindrical_coordinates
40+
sg.gas.cylindrical_velocities
3441
sg_copy = deepcopy(sg)
3542
# check private attribute to not trigger lazy loading
3643
assert_allclose_units(
@@ -45,6 +52,28 @@ def test_deepcopy_sg(self, sg):
4552
rtol=reltol_nd,
4653
atol=abstol_nd,
4754
)
55+
if derived_coords_initialized:
56+
assert_allclose_units(
57+
sg.gas.spherical_coordinates.r,
58+
sg_copy.gas._spherical_coordinates["_r"],
59+
)
60+
assert_allclose_units(
61+
sg.gas.spherical_velocities.r,
62+
sg_copy.gas._spherical_velocities["_v_r"],
63+
)
64+
assert_allclose_units(
65+
sg.gas.cylindrical_coordinates.rho,
66+
sg_copy.gas._cylindrical_coordinates["_rho"],
67+
)
68+
assert_allclose_units(
69+
sg.gas.cylindrical_velocities.rho,
70+
sg_copy.gas._cylindrical_velocities["_v_rho"],
71+
)
72+
else:
73+
assert sg_copy.gas._spherical_coordinates is None
74+
assert sg_copy.gas._spherical_velocities is None
75+
assert sg_copy.gas._cylindrical_coordinates is None
76+
assert sg_copy.gas._cylindrical_velocities is None
4877

4978

5079
class TestCopyDataset:

tests/test_creation.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import pytest
2-
from swiftgalaxy import SWIFTGalaxy, Standalone
1+
from swiftgalaxy import SWIFTGalaxy
32
from toysnap import (
43
create_toysnap,
54
remove_toysnap,
@@ -9,8 +8,6 @@
98
n_s_all,
109
n_bh_all,
1110
)
12-
from swiftsimio.objects import cosmo_array, cosmo_factor, a
13-
import unyt as u
1411

1512

1613
class TestSWIFTGalaxyCreation:
@@ -90,26 +87,9 @@ def test_no_masks(self):
9087
"""
9188
try:
9289
create_toysnap()
93-
with pytest.warns(UserWarning, match="No spatial_offsets provided."):
94-
sa = Standalone(
95-
extra_mask=None,
96-
centre=cosmo_array(
97-
[0, 0, 0],
98-
u.Mpc,
99-
comoving=True,
100-
cosmo_factor=cosmo_factor(a**1, 1.0),
101-
),
102-
velocity_centre=cosmo_array(
103-
[0, 0, 0],
104-
u.km / u.s,
105-
comoving=True,
106-
cosmo_factor=cosmo_factor(a**0, 1.0),
107-
),
108-
spatial_offsets=None,
109-
)
11090
sg = SWIFTGalaxy(
11191
toysnap_filename,
112-
sa,
92+
None, # no halo_catalogue is easiest way to get no mask
11393
transforms_like_coordinates={"coordinates", "extra_coordinates"},
11494
transforms_like_velocities={"velocities", "extra_velocities"},
11595
)
@@ -119,8 +99,7 @@ def test_no_masks(self):
11999
assert sg._extra_mask.stars is None
120100
assert sg._extra_mask.black_holes is None
121101
# check that cell mask is blank for all particle types:
122-
for cell_mask in sg._spatial_mask.cell_mask.values():
123-
assert cell_mask.all()
102+
assert sg._spatial_mask is None
124103
# check that we read all the particles:
125104
assert sg.gas.masses.size == n_g_all
126105
assert sg.dark_matter.masses.size == n_dm_all

tests/test_derived_coordinates.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,17 @@ def test_spherical_velocity_phi(self, sg, particle_name, alias):
206206
spherical_v_phi, v_phi_from_cartesian, rtol=1.0e-4, atol=abstol_v
207207
)
208208

209+
@pytest.mark.parametrize("ctype", ["coordinates", "velocities"])
210+
def test_copy_from_cylindrical(self, sg, ctype):
211+
"""
212+
Check that copying the azimuth from cylindrical if already evaluated works.
213+
"""
214+
getattr(sg.gas, f"cylindrical_{ctype}").phi # trigger evaluation
215+
assert (
216+
getattr(sg.gas, f"spherical_{ctype}").phi
217+
is getattr(sg.gas, f"cylindrical_{ctype}").phi
218+
)
219+
209220

210221
class TestCylindricalCoordinates:
211222
@pytest.mark.parametrize("particle_name", present_particle_types.values())
@@ -329,6 +340,17 @@ def test_cylindrical_velocity_z(self, sg, particle_name, alias):
329340
cylindrical_v_z, v_z_from_cartesian, rtol=1.0e-4, atol=abstol_v
330341
)
331342

343+
@pytest.mark.parametrize("ctype", ["coordinates", "velocities"])
344+
def test_copy_from_spherical(self, sg, ctype):
345+
"""
346+
Check that copying the azimuth from spherical if already evaluated works.
347+
"""
348+
getattr(sg.gas, f"spherical_{ctype}").phi # trigger evaluation
349+
assert (
350+
getattr(sg.gas, f"cylindrical_{ctype}").phi
351+
is getattr(sg.gas, f"spherical_{ctype}").phi
352+
)
353+
332354

333355
class TestInteractionWithCoordinateTransformations:
334356
@pytest.mark.parametrize("coordinate_type", ("coordinates", "velocities"))

tests/test_str.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,8 @@ def test_namedcolumn_fullname(self, sg):
2424
sg.gas.hydrogen_ionization_fractions._fullname
2525
== "gas.hydrogen_ionization_fractions"
2626
)
27+
28+
def test_sg_string(self, sg):
29+
string = str(sg)
30+
assert "SWIFTGalaxy at" in string
31+
assert repr(sg) == string

0 commit comments

Comments
 (0)