diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 3d3a6c6..9e71118 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -1,30 +1,105 @@ import gsd import gsd.hoomd +import freud -def frame_get_type_position(gsdfile, typename, frame=-1): +def get_type_position(type_name, gsd_file=None, snap=None, gsd_frame=-1): """ This function returns the positions of a particular particle - type from a frame of a gsd trajectory file. + type from a frame of a gsd trajectory file or from a snapshot. + Pass in either a gsd file or a snapshot, but not both. Parameters ---------- - gsdfile : str, - filename of the gsd trajectory - typename : str, + type_name : str, name of particles of which to get the positions (found in gsd.hoomd.Snapshot.particles.types) - frame : int, + gsd_file : str, + filename of the gsd trajectory (default = None) + snap : gsd.hoomd.Snapshot + Trajectory snapshot (default = None) + gsd_frame : int, frame number to get positions from. Supports - negative indexing. (default -1) + negative indexing. (default = -1) Returns ------- numpy.ndarray """ - with gsd.hoomd.open(name=gsdfile, mode='rb') as f: - snap = f[frame] - typepos = snap.particles.position[ - snap.particles.typeid == snap.particles.types.index(typename) + snap = _validate_inputs(gsd_file, snap, gsd_frame) + type_pos = snap.particles.position[ + snap.particles.typeid == snap.particles.types.index(type_name) ] - return typepos + return type_pos + +def get_all_types(gsd_file=None, snap=None, gsd_frame=-1): + """ + Returns all particle types in a hoomd trajectory + + Parameters + ---------- + gsd_file : str, + filename of the gsd trajectory (default = None) + snap : gsd.hoomd.Snapshot + Trajectory snapshot (default = None) + gsd_frame : int, + frame number to get positions from. Supports + negative indexing. (default = -1) + + Returns + ------- + numpy.ndarray + """ + snap = _validate_inputs(gsd_file, snap, gsd_frame) + return snap.particles.types + +def snap_molecule_cluster(gsd_file=None, snap=None, gsd_frame=-1): + """Find molecule index for each particle. + + Compute clusters of bonded molecules and return an array of the molecule + index of each particle. + Pass in either a gsd file or a snapshot, but not both + + Parameters + ---------- + gsd_file : str, + Filename of the gsd trajectory (default = None) + snap : gsd.hoomd.Snapshot + Trajectory snapshot. (default = None) + gsd_frame : int, + Frame number of gsd_file to use in computing clusters. (default = -1) + + Returns + ------- + numpy array (N_particles,) + """ + snap = _validate_inputs(gsd_file, snap, gsd_frame) + system = freud.AABBQuery.from_system(snap) + n_query_points = n_points = snap.particles.N + query_point_indices = snap.bonds.group[:, 0] + point_indices = snap.bonds.group[:, 1] + distances = system.box.compute_distances( + system.points[query_point_indices], system.points[point_indices] + ) + nlist = freud.NeighborList.from_arrays( + n_query_points, n_points, query_point_indices, point_indices, distances + ) + cluster = freud.cluster.Cluster() + cluster.compute(system=system, neighbors=nlist) + return cluster + + +def _validate_inputs(gsd_file, snap, gsd_frame): + if all([gsd_file, snap]): + raise ValueError("Only pass in one of snapshot, gsd_file") + if gsd_file: + assert isinstance(gsd_frame, int) + try: + with gsd.hoomd.open(name=gsd_file, mode='rb') as f: + snap = f[gsd_frame] + except Exception as e: + print("Unable to open the gsd_file") + raise e + elif snap: + assert isinstance(snap, gsd.hoomd.Snapshot) + return snap diff --git a/cmeutils/tests/base_test.py b/cmeutils/tests/base_test.py index a43dd40..f43c521 100644 --- a/cmeutils/tests/base_test.py +++ b/cmeutils/tests/base_test.py @@ -13,19 +13,35 @@ def test_gsd(self, tmp_path): create_gsd(filename) return filename + @pytest.fixture + def test_gsd_bonded(self, tmp_path): + filename = tmp_path / "test.gsd" + create_gsd(filename, add_bonds=True) + return filename + @pytest.fixture + def test_snap(self, test_gsd): + with gsd.hoomd.open(name=test_gsd, mode="rb") as f: + snap = f[-1] + return snap -def create_frame(i, seed=42): +def create_frame(i, add_bonds, seed=42): np.random.seed(seed) s = gsd.hoomd.Snapshot() s.configuration.step = i - s.particles.N = 4 + s.particles.N = 5 s.particles.types = ['A', 'B'] - s.particles.typeid = [0,0,1,1] - s.particles.position = np.random.random(size=(4,3)) + s.particles.typeid = [0,0,1,1,1] + s.particles.position = np.random.random(size=(5,3)) s.configuration.box = [3, 3, 3, 0, 0, 0] + if add_bonds: + s.bonds.N = 2 + s.bonds.types = ['AB'] + s.bonds.typeid = [0, 0] + s.bonds.group = [[0, 2], [1, 3]] + s.validate() return s -def create_gsd(filename): +def create_gsd(filename, add_bonds=False): with gsd.hoomd.open(name=filename, mode='wb') as f: - f.extend((create_frame(i) for i in range(10))) + f.extend((create_frame(i, add_bonds=add_bonds) for i in range(10))) diff --git a/cmeutils/tests/test_gsd.py b/cmeutils/tests/test_gsd.py index 0638fe7..34a31ab 100644 --- a/cmeutils/tests/test_gsd.py +++ b/cmeutils/tests/test_gsd.py @@ -1,12 +1,34 @@ import numpy as np +import pytest from cmeutils import gsd_utils from base_test import BaseTest class TestGSD(BaseTest): - def test_frame_get_type_position(self, test_gsd): - pos_array = gsd_utils.frame_get_type_position(test_gsd, 'A') + + def test_get_type_position(self, test_gsd): + from cmeutils.gsd_utils import get_type_position + + pos_array = get_type_position(gsd_file = test_gsd, type_name = 'A') assert type(pos_array) is type(np.array([])) assert pos_array.shape == (2,3) + def test_validate_inputs(self, test_gsd, test_snap): + # Catch error with both gsd_file and snap are passed + with pytest.raises(ValueError): + snap = gsd_utils._validate_inputs(test_gsd, test_snap, 1) + with pytest.raises(AssertionError): + snap = gsd_utils._validate_inputs(test_gsd, None, 1.0) + with pytest.raises(AssertionError): + snap = gsd_utils._validate_inputs(None, test_gsd, 1) + with pytest.raises(OSError): + gsd_utils._validate_inputs("bad_gsd_file", None, 0) + + def test_get_all_types(self, test_gsd): + types = gsd_utils.get_all_types(test_gsd) + assert types == ['A', 'B'] + + def test_snap_molecule_cluster(self, test_gsd_bonded): + cluster = gsd_utils.snap_molecule_cluster(gsd_file=test_gsd_bonded) +