Skip to content

Commit 8d52ece

Browse files
authored
fix(ParticleData): support partlocs as ndarray or list of lists (#1752)
1 parent c307420 commit 8d52ece

File tree

2 files changed

+58
-4
lines changed

2 files changed

+58
-4
lines changed

autotest/test_particledata.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
3+
from flopy.modpath import ParticleData
4+
5+
structured_plocs = [(1, 1, 1), (1, 1, 2)]
6+
structured_dtype = np.dtype(
7+
[
8+
("k", "<i4"),
9+
("i", "<i4"),
10+
("j", "<i4"),
11+
("localx", "<f4"),
12+
("localy", "<f4"),
13+
("localz", "<f4"),
14+
("timeoffset", "<f4"),
15+
("drape", "<i4"),
16+
]
17+
)
18+
structured_array = np.core.records.fromrecords(
19+
[
20+
(1, 1, 1, 0.5, 0.5, 0.5, 0.0, 0),
21+
(1, 1, 2, 0.5, 0.5, 0.5, 0.0, 0),
22+
],
23+
dtype=structured_dtype,
24+
)
25+
26+
27+
def test_particledata_structured_partlocs_as_list_of_tuples():
28+
locs = structured_plocs
29+
data = ParticleData(partlocs=locs, structured=True)
30+
31+
assert data.particlecount == 2
32+
assert data.dtype == structured_dtype
33+
assert np.array_equal(data.particledata, structured_array)
34+
35+
36+
def test_particledata_structured_partlocs_as_ndarray():
37+
locs = np.array(structured_plocs)
38+
data = ParticleData(partlocs=locs, structured=True)
39+
40+
assert data.particlecount == 2
41+
assert data.dtype == structured_dtype
42+
assert np.array_equal(data.particledata, structured_array)
43+
44+
45+
def test_particledata_structured_partlocs_as_list_of_lists():
46+
locs = [list(p) for p in structured_plocs]
47+
data = ParticleData(partlocs=locs, structured=True)
48+
49+
assert data.particlecount == 2
50+
assert data.dtype == structured_dtype
51+
assert np.array_equal(data.particledata, structured_array)

flopy/modpath/mp7particledata.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import numpy as np
9+
from numpy.lib.recfunctions import unstructured_to_structured
910

1011
from ..utils.recarray_utils import create_empty_recarray
1112

@@ -125,7 +126,7 @@ def __init__(
125126
alllen3 = all(len(el) == 3 for el in partlocs)
126127
if not alllen3:
127128
raise ValueError(
128-
"{}: all partlocs entries must have 3 items for "
129+
"{}: all partlocs entries must have 3 items for "
129130
"structured particle data".format(self.name)
130131
)
131132
else:
@@ -164,14 +165,16 @@ def __init__(
164165

165166
# convert partlocs composed of a lists/tuples of lists/tuples
166167
# to a numpy array
167-
partlocs = np.array(partlocs, dtype=dtype)
168+
partlocs = unstructured_to_structured(
169+
np.array(partlocs), dtype=dtype
170+
)
168171
elif isinstance(partlocs, np.ndarray):
169172
dtypein = partlocs.dtype
170173
if dtypein != dtype:
171-
partlocs = np.array(partlocs, dtype=dtype)
174+
partlocs = unstructured_to_structured(partlocs, dtype=dtype)
172175
else:
173176
raise ValueError(
174-
f"{self.name}: partlocs must be a list or tuple with lists or tuples"
177+
f"{self.name}: partlocs must be a list or tuple with lists or tuples, or an ndarray"
175178
)
176179

177180
# localx

0 commit comments

Comments
 (0)