From 66e824c801ac6e98f256f15bb1b53b37397ed6d0 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 6 Apr 2021 17:51:39 +1200 Subject: [PATCH 1/3] Refactor grdtrack to use virtualfile_from_data Enables `grdtrack` to work with table-like inputs besides pandas.DataFrame. Also, the `outfile` parameter has become optional, and the output data will be loaded as a pandas.DataFrame when `outfile` is unset. Unit tests have been updated accordingly too. --- pygmt/src/grdtrack.py | 23 +++++++++-------------- pygmt/tests/test_grdtrack.py | 9 ++++++--- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/pygmt/src/grdtrack.py b/pygmt/src/grdtrack.py index eaa1f0e13cc..2a75bc2489f 100644 --- a/pygmt/src/grdtrack.py +++ b/pygmt/src/grdtrack.py @@ -68,21 +68,13 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): - None if ``outfile`` is set (track output will be stored in file set by ``outfile``) """ + if data_kind(points) == "matrix" and newcolname is None: + raise GMTInvalidInput("Please pass in a str to 'newcolname'") with GMTTempFile(suffix=".csv") as tmpfile: with Session() as lib: - # Store the pandas.DataFrame points table in virtualfile - if data_kind(points) == "matrix": - if newcolname is None: - raise GMTInvalidInput("Please pass in a str to 'newcolname'") - table_context = lib.virtualfile_from_matrix(points.values) - elif data_kind(points) == "file": - if outfile is None: - raise GMTInvalidInput("Please pass in a str to 'outfile'") - table_context = dummy_context(points) - else: - raise GMTInvalidInput(f"Unrecognized data type {type(points)}") - + # Choose how data will be passed into the module + table_context = lib.virtualfile_from_data(check_kind="vector", data=points) # Store the xarray.DataArray grid in virtualfile grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid) @@ -100,8 +92,11 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): # Read temporary csv output to a pandas table if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - column_names = points.columns.to_list() + [newcolname] - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) + try: + column_names = points.columns.to_list() + [newcolname] + result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) + except AttributeError: # 'str' object has no attribute 'columns' + result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") elif outfile != tmpfile.name: # return None if outfile set, output in outfile result = None diff --git a/pygmt/tests/test_grdtrack.py b/pygmt/tests/test_grdtrack.py index 8ed74adcd47..44bae76709f 100644 --- a/pygmt/tests/test_grdtrack.py +++ b/pygmt/tests/test_grdtrack.py @@ -132,11 +132,14 @@ def test_grdtrack_without_newcolname_setting(dataarray): grdtrack(points=dataframe, grid=dataarray) -def test_grdtrack_without_outfile_setting(dataarray): +def test_grdtrack_without_outfile_setting(): """ Run grdtrack by not passing in outfile parameter setting. """ csvfile = which("@ridge.txt", download="c") + ncfile = which("@earth_relief_01d", download="a") - with pytest.raises(GMTInvalidInput): - grdtrack(points=csvfile, grid=dataarray) + output = grdtrack(points=csvfile, grid=ncfile) + npt.assert_allclose(output.iloc[0], [-32.2971, 37.4118, -1939.748245]) + + return output From 95d7c4b2e6355e62f5e73e676b0baf20d9a7693b Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 6 Apr 2021 18:08:27 +1200 Subject: [PATCH 2/3] State that outfile is not required anymore, and fix a lint error --- pygmt/src/grdtrack.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pygmt/src/grdtrack.py b/pygmt/src/grdtrack.py index 2a75bc2489f..366506a5922 100644 --- a/pygmt/src/grdtrack.py +++ b/pygmt/src/grdtrack.py @@ -8,7 +8,6 @@ GMTTempFile, build_arg_string, data_kind, - dummy_context, fmt_docstring, use_alias, ) @@ -51,8 +50,7 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): sampled values will be placed. outfile : str - Required if ``points`` is a file. The file name for the output ASCII - file. + The file name for the output ASCII file. {V} {f} From b375a36fb24fbd55321101214ff5240bea4d9cda Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Thu, 29 Apr 2021 15:53:01 +1200 Subject: [PATCH 3/3] Update expected grdtrack outputs to see if the test xpasses --- pygmt/tests/test_grdtrack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pygmt/tests/test_grdtrack.py b/pygmt/tests/test_grdtrack.py index 44bae76709f..60c728c53cd 100644 --- a/pygmt/tests/test_grdtrack.py +++ b/pygmt/tests/test_grdtrack.py @@ -36,7 +36,7 @@ def test_grdtrack_input_dataframe_and_dataarray(dataarray): output = grdtrack(points=dataframe, grid=dataarray, newcolname="bathymetry") assert isinstance(output, pd.DataFrame) assert output.columns.to_list() == ["longitude", "latitude", "bathymetry"] - npt.assert_allclose(output.iloc[0], [-110.9536, -42.2489, -2797.394987]) + npt.assert_allclose(output.iloc[0], [-110.9536, -42.2489, -2974.656296]) return output @@ -54,7 +54,7 @@ def test_grdtrack_input_csvfile_and_dataarray(dataarray): assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">") - npt.assert_allclose(track.iloc[0], [-110.9536, -42.2489, -2797.394987]) + npt.assert_allclose(track.iloc[0], [-110.9536, -42.2489, -2974.656296]) finally: os.remove(path=TEMP_TRACK)