Skip to content

Commit d379e46

Browse files
committed
Enable passing pyarrow.StringArray to clib.Session.put_strings
Convert a pyarrow.StringArray to a ctypes array, via an intermediate Python list, in the `clib.conversion.strings_to_ctypes_array` function. Update the docstrings and type hints of the `clib.Session.put_strings` method and the `clib.conversion.strings_to_ctypes_array` function. Add two parametrized unit tests to ensure that a pyarrow.StringArray can be passed into the clib methods.
1 parent 07fbca6 commit d379e46

File tree

4 files changed

+63
-25
lines changed

4 files changed

+63
-25
lines changed

pygmt/clib/conversion.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
import numpy as np
1010
from pygmt.exceptions import GMTInvalidInput
1111

12+
try:
13+
import pyarrow as pa
14+
except ImportError:
15+
pa = None
1216

1317
def dataarray_to_matrix(grid):
1418
"""
@@ -263,14 +267,15 @@ def sequence_to_ctypes_array(
263267
return (ctype * size)(*sequence)
264268

265269

266-
def strings_to_ctypes_array(strings: Sequence[str]) -> ctp.Array:
270+
def strings_to_ctypes_array(strings: Sequence[str] | pa.StringArray) -> ctp.Array:
267271
"""
268-
Convert a sequence (e.g., a list) of strings into a ctypes array.
272+
Convert a sequence (e.g., a list) of strings or a pyarrow.StringArray into a ctypes
273+
array.
269274
270275
Parameters
271276
----------
272277
strings
273-
A sequence of strings.
278+
A sequence of strings or a pyarrow.StringArray.
274279
275280
Returns
276281
-------
@@ -286,7 +291,12 @@ def strings_to_ctypes_array(strings: Sequence[str]) -> ctp.Array:
286291
>>> [s.decode() for s in ctypes_array]
287292
['first', 'second', 'third']
288293
"""
289-
return (ctp.c_char_p * len(strings))(*[s.encode() for s in strings])
294+
try:
295+
bytes_string_list = [s.encode() for s in strings]
296+
except AttributeError: # 'pyarrow.StringScalar' object has no attribute 'encode'
297+
# Convert pyarrow.StringArray to Python list first
298+
bytes_string_list = [s.encode() for s in strings.to_pylist()]
299+
return (ctp.c_char_p * len(strings))(*bytes_string_list)
290300

291301

292302
def array_to_datetime(array):

pygmt/clib/session.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@
3434
tempfile_from_image,
3535
)
3636

37+
try:
38+
import pyarrow as pa
39+
except ImportError:
40+
pa = None
41+
3742
FAMILIES = [
3843
"GMT_IS_DATASET", # Entity is a data table
3944
"GMT_IS_GRID", # Entity is a grid
@@ -936,39 +941,43 @@ def put_vector(self, dataset, column, vector):
936941
f"in column {column} of dataset."
937942
)
938943

939-
def put_strings(self, dataset, family, strings):
944+
def put_strings(
945+
self,
946+
dataset: ctp.c_void_p,
947+
family: Literal["GMT_IS_VECTOR", "GMT_IS_MATRIX"],
948+
strings: Sequence[str] | pa.StringArray,
949+
):
940950
"""
941-
Attach a numpy 1-D array of dtype str as a column on a GMT dataset.
951+
Attach a 1-D numpy array of dtype str or pyarrow.StringArray as a column on a
952+
GMT dataset.
942953
943-
Use this function to attach string type numpy array data to a GMT
944-
dataset and pass it to GMT modules. Wraps ``GMT_Put_Strings``.
954+
Use this function to attach string type array data to a GMT dataset and pass it
955+
to GMT modules. Wraps ``GMT_Put_Strings``.
945956
946-
The dataset must be created by :meth:`pygmt.clib.Session.create_data`
947-
first.
957+
The dataset must be created by :meth:`pygmt.clib.Session.create_data` first.
948958
949959
.. warning::
950-
The numpy array must be C contiguous in memory. If it comes from a
951-
column slice of a 2-D array, for example, you will have to make a
952-
copy. Use :func:`numpy.ascontiguousarray` to make sure your vector
953-
is contiguous (it won't copy if it already is).
960+
The array must be C contiguous in memory. If it comes from a column slice of
961+
a 2-D array, for example, you will have to make a copy. Use
962+
:func:`numpy.ascontiguousarray` to make sure your vector is contiguous (it
963+
won't copy if it already is).
954964
955965
Parameters
956966
----------
957-
dataset : :class:`ctypes.c_void_p`
967+
dataset
958968
The ctypes void pointer to a ``GMT_Dataset``. Create it with
959969
:meth:`pygmt.clib.Session.create_data`.
960-
family : str
970+
family
961971
The family type of the dataset. Can be either ``GMT_IS_VECTOR`` or
962972
``GMT_IS_MATRIX``.
963-
strings : numpy 1-D array
964-
The array that will be attached to the dataset. Must be a 1-D C
965-
contiguous array.
973+
strings
974+
The array that will be attached to the dataset. Must be a 1-D C contiguous
975+
array.
966976
967977
Raises
968978
------
969979
GMTCLibError
970-
If given invalid input or ``GMT_Put_Strings`` exits with
971-
status != 0.
980+
If given invalid input or ``GMT_Put_Strings`` exits with status != 0.
972981
"""
973982
c_put_strings = self.get_libgmt_func(
974983
"GMT_Put_Strings",

pygmt/tests/test_clib_put_strings.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,31 @@
88
from pygmt import clib
99
from pygmt.exceptions import GMTCLibError
1010
from pygmt.helpers import GMTTempFile
11+
from pygmt.helpers.testing import skip_if_no
12+
13+
try:
14+
import pyarrow as pa
15+
except ImportError:
16+
pa = None
1117

1218

1319
@pytest.mark.benchmark
14-
def test_put_strings():
20+
@pytest.mark.parametrize(
21+
("array_func", "dtype"),
22+
[
23+
pytest.param(np.array, {"dtype": str}, id="str"),
24+
pytest.param(
25+
getattr(pa, "array", None),
26+
{"type": pa.string()},
27+
marks=skip_if_no(package="pyarrow"),
28+
id="pyarrow",
29+
),
30+
],
31+
)
32+
def test_put_strings(array_func, dtype):
1533
"""
16-
Check that assigning a numpy array of dtype str to a dataset works.
34+
Check that assigning a numpy array of dtype str, or a pyarrow.StringArray to a
35+
dataset works.
1736
"""
1837
with clib.Session() as lib:
1938
dataset = lib.create_data(
@@ -24,7 +43,7 @@ def test_put_strings():
2443
)
2544
x = np.array([1, 2, 3, 4, 5], dtype=np.int32)
2645
y = np.array([6, 7, 8, 9, 10], dtype=np.int32)
27-
strings = np.array(["a", "bc", "defg", "hijklmn", "opqrst"], dtype=str)
46+
strings = array_func(["a", "bc", "defg", "hijklmn", "opqrst"], **dtype)
2847
lib.put_vector(dataset, column=lib["GMT_X"], vector=x)
2948
lib.put_vector(dataset, column=lib["GMT_Y"], vector=y)
3049
lib.put_strings(

pygmt/tests/test_clib_virtualfiles.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def test_virtualfile_from_vectors(dtypes):
238238
pytest.param(np.array, {"dtype": object}, id="object"),
239239
pytest.param(
240240
getattr(pa, "array", None),
241-
{}, # pa.string()
241+
{"type": pa.string()},
242242
marks=skip_if_no(package="pyarrow"),
243243
id="pyarrow",
244244
),

0 commit comments

Comments
 (0)