Skip to content

Allow pandas.DataFrame table and 1D/2D numpy array inputs into pygmt.info #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions pygmt/modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Non-plot GMT modules.
"""
import numpy as np
import xarray as xr

from .clib import Session
Expand Down Expand Up @@ -55,7 +56,7 @@ def grdinfo(grid, **kwargs):

@fmt_docstring
@use_alias(C="per_column", I="spacing", T="nearest_multiple")
def info(fname, **kwargs):
def info(table, **kwargs):
"""
Get information about data tables.

Expand All @@ -74,8 +75,9 @@ def info(fname, **kwargs):

Parameters
----------
fname : str
The file name of the input data table file.
table : pandas.DataFrame or np.ndarray or str
Either a pandas dataframe, a 1D/2D numpy.ndarray or a file name to an
ASCII data table.
per_column : bool
Report the min/max values per column in separate columns.
spacing : str
Expand All @@ -88,14 +90,25 @@ def info(fname, **kwargs):
Report the min/max of the first (0'th) column to the nearest multiple
of dz and output this as the string *-Tzmin/zmax/dz*.
"""
if not isinstance(fname, str):
raise GMTInvalidInput("'info' only accepts file names.")
kind = data_kind(table)
with Session() as lib:
if kind == "file":
file_context = dummy_context(table)
elif kind == "matrix":
_table = np.asanyarray(table)
Copy link
Member Author

@weiji14 weiji14 Sep 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI, this line converts a pandas.DataFrame into a numpy.ndarray, while allowing numpy.ndarray objects to pass through unchanged. It is a bit simpler than using `_table = table.values if hasattr(table, "values") else table`, and I think it allows Python lists to work too (though I haven't tested that).

See also:

if table.ndim == 1: # 1D arrays need to be 2D and transposed
_table = np.transpose(np.atleast_2d(_table))
file_context = lib.virtualfile_from_matrix(_table)
else:
raise GMTInvalidInput(f"Unrecognized data type: {type(table)}")

with GMTTempFile() as tmpfile:
arg_str = " ".join([fname, build_arg_string(kwargs), "->" + tmpfile.name])
with Session() as lib:
lib.call_module("info", arg_str)
return tmpfile.read()
with GMTTempFile() as tmpfile:
with file_context as fname:
arg_str = " ".join(
[fname, build_arg_string(kwargs), "->" + tmpfile.name]
)
lib.call_module("info", arg_str)
return tmpfile.read()


@fmt_docstring
Expand Down
50 changes: 40 additions & 10 deletions pygmt/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import os

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from .. import info
from ..exceptions import GMTInvalidInput
Expand All @@ -14,8 +16,8 @@


def test_info():
"Make sure info works"
output = info(fname=POINTS_DATA)
"Make sure info works on file name inputs"
output = info(table=POINTS_DATA)
expected_output = (
f"{POINTS_DATA}: N = 20 "
"<11.5309/61.7074> "
Expand All @@ -25,33 +27,61 @@ def test_info():
assert output == expected_output


def test_info_dataframe():
    """
    Check that info accepts a pandas.DataFrame as its table input.

    The points data file is read into a header-less, space-separated
    DataFrame, and the summary string returned by info is compared with the
    expected one-line report.
    """
    dataframe = pd.read_csv(POINTS_DATA, sep=" ", header=None)
    expected = (
        "<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
    )
    assert info(table=dataframe) == expected


def test_info_2d_array():
    """
    Check that info accepts a 2D numpy.ndarray as its table input.

    The points data file is loaded with numpy.loadtxt, and the summary string
    returned by info is compared with the expected one-line report.
    """
    matrix = np.loadtxt(POINTS_DATA)
    expected = (
        "<matrix memory>: N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n"
    )
    assert info(table=matrix) == expected


def test_info_1d_array():
    """
    Check that info accepts a 1D numpy.ndarray as its table input.

    The integers 0 through 19 are passed as a single column, and the summary
    string returned by info is compared with the expected one-line report.
    """
    values = np.arange(20)
    assert info(table=values) == "<matrix memory>: N = 20 <0/19>\n"


def test_info_per_column():
"Make sure the per_column option works"
output = info(fname=POINTS_DATA, per_column=True)
output = info(table=POINTS_DATA, per_column=True)
assert output == "11.5309 61.7074 -2.9289 7.8648 0.1412 0.9338\n"


def test_info_spacing():
"Make sure the spacing option works"
output = info(fname=POINTS_DATA, spacing=0.1)
output = info(table=POINTS_DATA, spacing=0.1)
assert output == "-R11.5/61.8/-3/7.9\n"


def test_info_per_column_spacing():
"Make sure the per_column and spacing options work together"
output = info(fname=POINTS_DATA, per_column=True, spacing=0.1)
output = info(table=POINTS_DATA, per_column=True, spacing=0.1)
assert output == "11.5 61.8 -3 7.9 0.1412 0.9338\n"


def test_info_nearest_multiple():
"Make sure the nearest_multiple option works"
output = info(fname=POINTS_DATA, nearest_multiple=0.1)
output = info(table=POINTS_DATA, nearest_multiple=0.1)
assert output == "-T11.5/61.8/0.1\n"


def test_info_fails():
"Make sure info raises an exception if not given a file name"
with pytest.raises(GMTInvalidInput):
info(fname=21)
"""
Make sure info raises an exception if not given either a file name, pandas
DataFrame, or numpy ndarray
"""
with pytest.raises(GMTInvalidInput):
info(fname=np.arange(20))
info(table=xr.DataArray(21))