Skip to content

Standardize file type handling in gsd_utils #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 87 additions & 12 deletions cmeutils/gsd_utils.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,105 @@
import gsd
import gsd.hoomd
import freud


def frame_get_type_position(gsdfile, typename, frame=-1):
def get_type_position(type_name, gsd_file=None, snap=None, gsd_frame=-1):
"""
This function returns the positions of a particular particle
type from a frame of a gsd trajectory file.
type from a frame of a gsd trajectory file or from a snapshot.
Pass in either a gsd file or a snapshot, but not both.

Parameters
----------
gsdfile : str,
filename of the gsd trajectory
typename : str,
type_name : str,
name of particles of which to get the positions
(found in gsd.hoomd.Snapshot.particles.types)
frame : int,
gsd_file : str,
filename of the gsd trajectory (default = None)
snap : gsd.hoomd.Snapshot
Trajectory snapshot (default = None)
gsd_frame : int,
frame number to get positions from. Supports
negative indexing. (default -1)
negative indexing. (default = -1)

Returns
-------
numpy.ndarray
"""
with gsd.hoomd.open(name=gsdfile, mode='rb') as f:
snap = f[frame]
typepos = snap.particles.position[
snap.particles.typeid == snap.particles.types.index(typename)
snap = _validate_inputs(gsd_file, snap, gsd_frame)
type_pos = snap.particles.position[
snap.particles.typeid == snap.particles.types.index(type_name)
]
return typepos
return type_pos

def get_all_types(gsd_file=None, snap=None, gsd_frame=-1):
"""
Returns all particle types in a hoomd trajectory

Parameters
----------
gsd_file : str,
filename of the gsd trajectory (default = None)
snap : gsd.hoomd.Snapshot
Trajectory snapshot (default = None)
gsd_frame : int,
frame number to get positions from. Supports
negative indexing. (default = -1)

Returns
-------
numpy.ndarray
"""
snap = _validate_inputs(gsd_file, snap, gsd_frame)
return snap.particles.types

def snap_molecule_cluster(gsd_file=None, snap=None, gsd_frame=-1):
"""Find molecule index for each particle.

Compute clusters of bonded molecules and return an array of the molecule
index of each particle.
Pass in either a gsd file or a snapshot, but not both

Parameters
----------
gsd_file : str,
Filename of the gsd trajectory (default = None)
snap : gsd.hoomd.Snapshot
Trajectory snapshot. (default = None)
gsd_frame : int,
Frame number of gsd_file to use in computing clusters. (default = -1)

Returns
-------
numpy array (N_particles,)
"""
snap = _validate_inputs(gsd_file, snap, gsd_frame)
system = freud.AABBQuery.from_system(snap)
n_query_points = n_points = snap.particles.N
query_point_indices = snap.bonds.group[:, 0]
point_indices = snap.bonds.group[:, 1]
distances = system.box.compute_distances(
system.points[query_point_indices], system.points[point_indices]
)
nlist = freud.NeighborList.from_arrays(
n_query_points, n_points, query_point_indices, point_indices, distances
)
cluster = freud.cluster.Cluster()
cluster.compute(system=system, neighbors=nlist)
return cluster


def _validate_inputs(gsd_file, snap, gsd_frame):
if all([gsd_file, snap]):
raise ValueError("Only pass in one of snapshot, gsd_file")
if gsd_file:
assert isinstance(gsd_frame, int)
try:
with gsd.hoomd.open(name=gsd_file, mode='rb') as f:
snap = f[gsd_frame]
except Exception as e:
print("Unable to open the gsd_file")
raise e
elif snap:
assert isinstance(snap, gsd.hoomd.Snapshot)
return snap
28 changes: 22 additions & 6 deletions cmeutils/tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,35 @@ def test_gsd(self, tmp_path):
create_gsd(filename)
return filename

@pytest.fixture
def test_gsd_bonded(self, tmp_path):
filename = tmp_path / "test.gsd"
create_gsd(filename, add_bonds=True)
return filename

@pytest.fixture
def test_snap(self, test_gsd):
with gsd.hoomd.open(name=test_gsd, mode="rb") as f:
snap = f[-1]
return snap

def create_frame(i, seed=42):
def create_frame(i, add_bonds, seed=42):
np.random.seed(seed)
s = gsd.hoomd.Snapshot()
s.configuration.step = i
s.particles.N = 4
s.particles.N = 5
s.particles.types = ['A', 'B']
s.particles.typeid = [0,0,1,1]
s.particles.position = np.random.random(size=(4,3))
s.particles.typeid = [0,0,1,1,1]
s.particles.position = np.random.random(size=(5,3))
s.configuration.box = [3, 3, 3, 0, 0, 0]
if add_bonds:
s.bonds.N = 2
s.bonds.types = ['AB']
s.bonds.typeid = [0, 0]
s.bonds.group = [[0, 2], [1, 3]]
s.validate()
return s

def create_gsd(filename):
def create_gsd(filename, add_bonds=False):
with gsd.hoomd.open(name=filename, mode='wb') as f:
f.extend((create_frame(i) for i in range(10)))
f.extend((create_frame(i, add_bonds=add_bonds) for i in range(10)))
26 changes: 24 additions & 2 deletions cmeutils/tests/test_gsd.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
import numpy as np
import pytest

from cmeutils import gsd_utils
from base_test import BaseTest


class TestGSD(BaseTest):
def test_frame_get_type_position(self, test_gsd):
pos_array = gsd_utils.frame_get_type_position(test_gsd, 'A')

def test_get_type_position(self, test_gsd):
from cmeutils.gsd_utils import get_type_position

pos_array = get_type_position(gsd_file = test_gsd, type_name = 'A')
assert type(pos_array) is type(np.array([]))
assert pos_array.shape == (2,3)

def test_validate_inputs(self, test_gsd, test_snap):
# Catch error with both gsd_file and snap are passed
with pytest.raises(ValueError):
snap = gsd_utils._validate_inputs(test_gsd, test_snap, 1)
with pytest.raises(AssertionError):
snap = gsd_utils._validate_inputs(test_gsd, None, 1.0)
with pytest.raises(AssertionError):
snap = gsd_utils._validate_inputs(None, test_gsd, 1)
with pytest.raises(OSError):
gsd_utils._validate_inputs("bad_gsd_file", None, 0)

def test_get_all_types(self, test_gsd):
types = gsd_utils.get_all_types(test_gsd)
assert types == ['A', 'B']

def test_snap_molecule_cluster(self, test_gsd_bonded):
cluster = gsd_utils.snap_molecule_cluster(gsd_file=test_gsd_bonded)