Skip to content

Commit 537d684

Browse files
seismanweiji14
andauthored
Wrap GMT's standard data type GMT_IMAGE for images (#3338)
Co-authored-by: Wei Ji <[email protected]>
1 parent ff246c6 commit 537d684

File tree

4 files changed

+149
-11
lines changed

4 files changed

+149
-11
lines changed

pygmt/clib/session.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
vectors_to_arrays,
2727
)
2828
from pygmt.clib.loading import load_libgmt
29-
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID
29+
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE
3030
from pygmt.exceptions import (
3131
GMTCLibError,
3232
GMTCLibNoSessionError,
@@ -1071,7 +1071,7 @@ def put_matrix(self, dataset, matrix, pad=0):
10711071
def read_data(
10721072
self,
10731073
infile: str,
1074-
kind: Literal["dataset", "grid"],
1074+
kind: Literal["dataset", "grid", "image"],
10751075
family: str | None = None,
10761076
geometry: str | None = None,
10771077
mode: str = "GMT_READ_NORMAL",
@@ -1089,8 +1089,8 @@ def read_data(
10891089
infile
10901090
The input file name.
10911091
kind
1092-
The data kind of the input file. Valid values are ``"dataset"`` and
1093-
``"grid"``.
1092+
The data kind of the input file. Valid values are ``"dataset"``, ``"grid"``
1093+
and ``"image"``.
10941094
family
10951095
A valid GMT data family name (e.g., ``"GMT_IS_DATASET"``). See the
10961096
``FAMILIES`` attribute for valid names. If ``None``, will determine the data
@@ -1141,6 +1141,7 @@ def read_data(
11411141
_family, _geometry, dtype = {
11421142
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP", _GMT_DATASET),
11431143
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE", _GMT_GRID),
1144+
"image": ("GMT_IS_IMAGE", "GMT_IS_SURFACE", _GMT_IMAGE),
11441145
}[kind]
11451146
if family is None:
11461147
family = _family
@@ -1797,7 +1798,9 @@ def virtualfile_from_data(
17971798

17981799
@contextlib.contextmanager
17991800
def virtualfile_out(
1800-
self, kind: Literal["dataset", "grid"] = "dataset", fname: str | None = None
1801+
self,
1802+
kind: Literal["dataset", "grid", "image"] = "dataset",
1803+
fname: str | None = None,
18011804
) -> Generator[str, None, None]:
18021805
r"""
18031806
Create a virtual file or an actual file for storing output data.
@@ -1810,8 +1813,8 @@ def virtualfile_out(
18101813
Parameters
18111814
----------
18121815
kind
1813-
The data kind of the virtual file to create. Valid values are ``"dataset"``
1814-
and ``"grid"``. Ignored if ``fname`` is specified.
1816+
The data kind of the virtual file to create. Valid values are ``"dataset"``,
1817+
``"grid"``, and ``"image"``. Ignored if ``fname`` is specified.
18151818
fname
18161819
The name of the actual file to write the output data. No virtual file will
18171820
be created.
@@ -1854,8 +1857,10 @@ def virtualfile_out(
18541857
family, geometry = {
18551858
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP"),
18561859
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE"),
1860+
"image": ("GMT_IS_IMAGE", "GMT_IS_SURFACE"),
18571861
}[kind]
1858-
with self.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile:
1862+
direction = "GMT_OUT|GMT_IS_REFERENCE" if kind == "image" else "GMT_OUT"
1863+
with self.open_virtualfile(family, geometry, direction, None) as vfile:
18591864
yield vfile
18601865

18611866
def inquire_virtualfile(self, vfname: str) -> int:
@@ -1901,7 +1906,8 @@ def read_virtualfile(
19011906
Name of the virtual file to read.
19021907
kind
19031908
Cast the data into a GMT data container. Valid values are ``"dataset"``,
1904-
``"grid"`` and ``None``. If ``None``, will return a ctypes void pointer.
1909+
``"grid"``, ``"image"`` and ``None``. If ``None``, will return a ctypes void
1910+
pointer.
19051911
19061912
Returns
19071913
-------
@@ -1951,9 +1957,9 @@ def read_virtualfile(
19511957
# _GMT_DATASET).
19521958
if kind is None: # Return the ctypes void pointer
19531959
return pointer
1954-
if kind in {"image", "cube"}:
1960+
if kind == "cube":
19551961
raise NotImplementedError(f"kind={kind} is not supported yet.")
1956-
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID}[kind]
1962+
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID, "image": _GMT_IMAGE}[kind]
19571963
return ctp.cast(pointer, ctp.POINTER(dtype))
19581964

19591965
def virtualfile_to_dataset(

pygmt/datatypes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from pygmt.datatypes.dataset import _GMT_DATASET
66
from pygmt.datatypes.grid import _GMT_GRID
7+
from pygmt.datatypes.image import _GMT_IMAGE

pygmt/datatypes/image.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
Wrapper for the GMT_IMAGE data type.
3+
"""
4+
5+
import ctypes as ctp
6+
from typing import ClassVar
7+
8+
from pygmt.datatypes.grid import _GMT_GRID_HEADER
9+
10+
11+
class _GMT_IMAGE(ctp.Structure): # noqa: N801
12+
"""
13+
GMT image data structure.
14+
15+
Examples
16+
--------
17+
>>> import numpy as np
18+
>>> from pygmt.clib import Session
19+
>>> with Session() as lib:
20+
... with lib.virtualfile_out(kind="image") as voutimg:
21+
... lib.call_module("read", ["@earth_day_01d", voutimg, "-Ti"])
22+
... # Read the image from the virtual file
23+
... image = lib.read_virtualfile(vfname=voutimg, kind="image").contents
24+
... # The image header
25+
... header = image.header.contents
26+
... # Access the header properties
27+
... print(header.n_rows, header.n_columns, header.registration)
28+
... print(header.wesn[:], header.inc[:])
29+
... print(header.z_scale_factor, header.z_add_offset)
30+
... print(header.x_units, header.y_units, header.z_units)
31+
... print(header.title)
32+
... print(header.command)
33+
... print(header.remark)
34+
... print(header.nm, header.size, header.complex_mode)
35+
... print(header.type, header.n_bands, header.mx, header.my)
36+
... print(header.pad[:])
37+
... print(header.mem_layout, header.nan_value, header.xy_off)
38+
... # Image-specific attributes.
39+
... print(image.type, image.n_indexed_colors)
40+
... # The x and y coordinates
41+
... x = image.x[: header.n_columns]
42+
... y = image.y[: header.n_rows]
43+
... # The data array (with paddings)
44+
... data = np.reshape(
45+
... image.data[: header.n_bands * header.mx * header.my],
46+
... (header.my, header.mx, header.n_bands),
47+
... )
48+
... # The data array (without paddings)
49+
... pad = header.pad[:]
50+
... data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :]
51+
180 360 1
52+
[-180.0, 180.0, -90.0, 90.0] [1.0, 1.0]
53+
1.0 0.0
54+
b'x' b'y' b'z'
55+
b''
56+
b''
57+
b''
58+
64800 66976 0
59+
0 3 364 184
60+
[2, 2, 2, 2]
61+
b'BRPa' 0.0 0.5
62+
1 0
63+
>>> x
64+
[-179.5, -178.5, ..., 178.5, 179.5]
65+
>>> y
66+
[89.5, 88.5, ..., -88.5, -89.5]
67+
>>> data.shape
68+
(180, 360, 3)
69+
>>> data.min(), data.max()
70+
(10, 255)
71+
"""
72+
73+
_fields_: ClassVar = [
74+
# Data type, e.g. GMT_FLOAT
75+
("type", ctp.c_int),
76+
# Array with color lookup values
77+
("colormap", ctp.POINTER(ctp.c_int)),
78+
# Number of colors in a paletted image
79+
("n_indexed_colors", ctp.c_int),
80+
# Pointer to full GMT header for the image
81+
("header", ctp.POINTER(_GMT_GRID_HEADER)),
82+
# Pointer to actual image
83+
("data", ctp.POINTER(ctp.c_ubyte)),
84+
# Pointer to an optional transparency layer stored in a separate variable
85+
("alpha", ctp.POINTER(ctp.c_ubyte)),
86+
# Color interpolation
87+
("color_interp", ctp.c_char_p),
88+
# Pointer to the x-coordinate vector
89+
("x", ctp.POINTER(ctp.c_double)),
90+
# Pointer to the y-coordinate vector
91+
("y", ctp.POINTER(ctp.c_double)),
92+
# Book-keeping variables "hidden" from the API
93+
("hidden", ctp.c_void_p),
94+
]

pygmt/tests/test_clib_read_data.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,43 @@ def test_clib_read_data_grid_actual_image():
132132
)
133133

134134

135+
# Note: Simplify the tests for images after GMT_IMAGE.to_dataarray() is implemented.
136+
def test_clib_read_data_image():
137+
"""
138+
Test the Session.read_data method for images.
139+
"""
140+
with Session() as lib:
141+
image = lib.read_data("@earth_day_01d_p", kind="image").contents
142+
header = image.header.contents
143+
assert header.n_rows == 180
144+
assert header.n_columns == 360
145+
assert header.n_bands == 3
146+
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
147+
assert image.data
148+
149+
150+
def test_clib_read_data_image_two_steps():
151+
"""
152+
Test the Session.read_data method for images in two steps, first reading the header
153+
and then the data.
154+
"""
155+
infile = "@earth_day_01d_p"
156+
with Session() as lib:
157+
# Read the header first
158+
data_ptr = lib.read_data(infile, kind="image", mode="GMT_CONTAINER_ONLY")
159+
image = data_ptr.contents
160+
header = image.header.contents
161+
assert header.n_rows == 180
162+
assert header.n_columns == 360
163+
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
164+
assert header.n_bands == 3 # Explicitly check n_bands
165+
assert not image.data # The data is not read yet
166+
167+
# Read the data
168+
lib.read_data(infile, kind="image", mode="GMT_DATA_ONLY", data=data_ptr)
169+
assert image.data
170+
171+
135172
def test_clib_read_data_fails():
136173
"""
137174
Test that the Session.read_data method raises an exception if there are errors.

0 commit comments

Comments
 (0)