Skip to content

Add create_test_data to public testing API #2690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ Testing
testing.assert_equal
testing.assert_identical
testing.assert_allclose
testing.create_test_data

Exceptions
==========
Expand Down
6 changes: 6 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ Enhancements
report showing what exactly differs between the two objects (dimensions /
coordinates / variables / attributes) (:issue:`1507`).
By `Benoit Bovy <https://github.com/benbovy>`_.
- The function :py:func:`xarray.testing.create_test_data` has been added to the
public testing API. It creates a small example dataset, and is provided as a
convenience function for when you are writing tests for code which acts on
xarray objects (:issue:`2686`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.


Bug fixes
~~~~~~~~~
Expand Down
62 changes: 57 additions & 5 deletions xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from __future__ import absolute_import, division, print_function

import numpy as np
from pandas import date_range

from xarray.core import duck_array_ops
from xarray.core import formatting
from xarray import Dataset


def _decode_string_data(data):
Expand Down Expand Up @@ -145,8 +147,58 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
.format(type(a)))


def assert_combined_tile_ids_equal(dict1, dict2):
    """
    Assert that two mappings of tile IDs to xarray objects match.

    The mappings must have the same number of entries, every key of
    ``dict1`` must be present in ``dict2``, and each pair of values stored
    under the same key must satisfy ``assert_equal``.

    Parameters
    ----------
    dict1, dict2 : dict-like
        Mappings to compare.

    Raises
    ------
    AssertionError
        If the lengths differ, a key of ``dict1`` is missing from ``dict2``,
        or ``assert_equal`` fails for a pair of values.
    """
    assert len(dict1) == len(dict2)
    for k, v in dict1.items():
        # Equal lengths plus "every key of dict1 is in dict2" implies the
        # two key sets are identical.
        assert k in dict2
        assert_equal(v, dict2[k])
def create_test_data(seed=None):
    """
    Create an example dataset for use when testing functions which act on
    xarray objects.

    The dataset returned covers several possible edge cases, including
    dimensions with and without coordinates, and datetime, float, integer
    and string coordinate values.

    Used extensively within xarray's own test suite.

    Parameters
    ----------
    seed : int, optional
        Seed to use for random data (passed to `np.random.RandomState`),
        default is `None`.

    Returns
    -------
    dataset : xarray.Dataset
        Dataset with data variables ``var1``, ``var2`` and ``var3`` (each
        carrying the attribute ``{'foo': 'variable'}``), coordinates
        ``time``, ``dim2``, ``dim3`` and ``numbers``, and the encoding
        ``{'foo': 'bar'}``.

    Examples
    --------
    >>> import xarray as xr
    >>> xr.testing.create_test_data(seed=0)
    <xarray.Dataset>
    Dimensions:  (dim1: 8, dim2: 9, dim3: 10, time: 20)
    Coordinates:
      * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20
      * dim2     (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0
      * dim3     (dim3) <U1 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
        numbers  (dim3) int64 0 1 2 0 0 1 1 2 2 3
    Dimensions without coordinates: dim1
    Data variables:
        var1     (dim1, dim2) float64 1.764 0.4002 0.9787 ...
        var2     (dim1, dim2) float64 1.139 -1.235 0.4023 ...
        var3     (dim3, dim1) float64 2.383 0.9445 -0.9128 ...
    """
    rs = np.random.RandomState(seed)
    _vars = {'var1': ['dim1', 'dim2'],
             'var2': ['dim1', 'dim2'],
             'var3': ['dim3', 'dim1']}
    _dims = {'dim1': 8, 'dim2': 9, 'dim3': 10}

    obj = Dataset()
    obj['time'] = ('time', date_range('2000-01-01', periods=20))
    obj['dim2'] = ('dim2', 0.5 * np.arange(_dims['dim2']))
    obj['dim3'] = ('dim3', list('abcdefghij'))
    # Iterate in sorted order so the random draws are assigned to the same
    # variables for a given seed regardless of dict ordering.
    for v, dims in sorted(_vars.items()):
        data = rs.normal(size=tuple(_dims[d] for d in dims))
        obj[v] = (dims, data, {'foo': 'variable'})
    obj.coords['numbers'] = ('dim3', np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3],
                                              dtype='int64'))
    obj.encoding = {'foo': 'bar'}
    # Sanity check: every variable's underlying array must be writeable.
    assert all(var.data.flags.writeable for var in obj.variables.values())
    return obj
2 changes: 1 addition & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from xarray.core.options import set_options
from xarray.core.indexing import ExplicitlyIndexed
from xarray.testing import (assert_equal, assert_identical, # noqa: F401
assert_allclose, assert_combined_tile_ids_equal)
assert_allclose)
from xarray.plot.utils import import_seaborn

try:
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from xarray.core.pycompat import (
ExitStack, basestring, dask_array_type, iteritems)
from xarray.tests import mock
from xarray.testing import create_test_data


from . import (
assert_allclose, assert_array_equal, assert_equal, assert_identical,
Expand All @@ -37,7 +39,6 @@
requires_pathlib, requires_pseudonetcdf, requires_pydap, requires_pynio,
requires_rasterio, requires_scipy, requires_scipy_or_netCDF4,
requires_zarr)
from .test_dataset import create_test_data

try:
import netCDF4 as nc4
Expand Down
13 changes: 10 additions & 3 deletions xarray/tests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
_check_shape_tile_ids, _combine_nd, _infer_concat_order_from_positions,
_infer_tile_ids_from_nested_list, _new_tile_id)
from xarray.core.pycompat import OrderedDict, iteritems
from xarray.testing import create_test_data

from . import (
InaccessibleArray, assert_array_equal, assert_combined_tile_ids_equal,
assert_equal, assert_identical, raises_regex, requires_dask)
from .test_dataset import create_test_data
InaccessibleArray, assert_array_equal, assert_equal, assert_identical,
raises_regex, requires_dask)


class TestConcatDataset(object):
Expand Down Expand Up @@ -411,6 +411,13 @@ def test_auto_combine_no_concat(self):
assert_identical(expected, actual)


def assert_combined_tile_ids_equal(dict1, dict2):
    """Assert two tile-ID mappings have the same keys and equal values.

    Fails with ``AssertionError`` if the mappings differ in length, if a key
    of ``dict1`` is absent from ``dict2``, or if ``assert_equal`` rejects a
    pair of values stored under the same key.
    """
    assert len(dict1) == len(dict2)
    for k, v in dict1.items():
        assert k in dict2
        assert_equal(v, dict2[k])


class TestTileIDsFromNestedList(object):
def test_1d(self):
ds = create_test_data
Expand Down
22 changes: 1 addition & 21 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from xarray.core.common import full_like
from xarray.core.pycompat import (
OrderedDict, integer_types, iteritems, unicode_type)
from xarray.testing import create_test_data

from . import (
InaccessibleArray, UnexpectedDataAccess, assert_allclose,
Expand All @@ -33,27 +34,6 @@
pass


def create_test_data(seed=None):
    """Build the small example Dataset shared by the test modules.

    ``seed`` is forwarded to ``np.random.RandomState`` and controls the
    normally distributed values stored in ``var1``, ``var2`` and ``var3``.
    """
    rng = np.random.RandomState(seed)
    sizes = {'dim1': 8, 'dim2': 9, 'dim3': 10}
    layout = {
        'var1': ['dim1', 'dim2'],
        'var2': ['dim1', 'dim2'],
        'var3': ['dim3', 'dim1'],
    }

    ds = Dataset()
    ds['time'] = ('time', pd.date_range('2000-01-01', periods=20))
    ds['dim2'] = ('dim2', 0.5 * np.arange(sizes['dim2']))
    ds['dim3'] = ('dim3', list('abcdefghij'))

    # Names are unique, so sorting the keys visits the variables in the same
    # order as sorting the (name, dims) pairs.
    for name in sorted(layout):
        dims = layout[name]
        shape = tuple(sizes[d] for d in dims)
        ds[name] = (dims, rng.normal(size=shape), {'foo': 'variable'})

    numbers = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype='int64')
    ds.coords['numbers'] = ('dim3', numbers)
    ds.encoding = {'foo': 'bar'}
    assert all(var.data.flags.writeable for var in ds.variables.values())
    return ds


def create_test_multiindex():
mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]],
names=('level_1', 'level_2'))
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file,
create_tmp_geotiff,
open_example_dataset)
from xarray.tests.test_dataset import create_test_data
from xarray.testing import create_test_data

from . import (
assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy,
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import xarray as xr
from xarray.tests import (
assert_allclose, assert_equal, requires_cftime, requires_scipy)
from xarray.testing import create_test_data

from . import has_dask, has_scipy
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from .test_dataset import create_test_data

try:
import scipy
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import xarray as xr
from xarray.core import merge
from xarray.testing import create_test_data

from . import raises_regex
from .test_dataset import create_test_data


class TestMergeInternals(object):
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from xarray import concat, merge
from xarray.backends.file_manager import FILE_CACHE
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.tests.test_dataset import create_test_data
from xarray.testing import create_test_data


def test_invalid_option_raises():
Expand Down