diff --git a/src/python/nimbusml/internal/utils/data_stream.py b/src/python/nimbusml/internal/utils/data_stream.py index ea544307..165c77f4 100644 --- a/src/python/nimbusml/internal/utils/data_stream.py +++ b/src/python/nimbusml/internal/utils/data_stream.py @@ -8,6 +8,7 @@ import os import tempfile from shutil import copyfile +from pathlib import Path from .data_roles import DataRoles from .data_schema import DataSchema @@ -229,6 +230,10 @@ def __init__(self, filename, schema, roles=None): :param schema: filename schema """ super(FileDataStream, self).__init__(schema, roles) + + if isinstance(filename, Path): + filename = str(filename.resolve()) + self._filename = filename def __repr__(self): diff --git a/src/python/nimbusml/tests/test_data_stream.py b/src/python/nimbusml/tests/test_data_stream.py index 744c1854..c6570a91 100644 --- a/src/python/nimbusml/tests/test_data_stream.py +++ b/src/python/nimbusml/tests/test_data_stream.py @@ -10,6 +10,7 @@ import pandas from nimbusml import DataSchema from nimbusml import FileDataStream +from pathlib import Path try: from pandas.testing import assert_frame_equal @@ -30,6 +31,17 @@ def test_data_stream(self): assert repr(fi) == repr(fi2) os.remove(f.name) + def test_data_stream_path_object(self): + df = pandas.DataFrame(dict(a=[0, 1], b=[0.1, 0.2])) + with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + df.to_csv(f, sep=',', index=False) + + fi = FileDataStream.read_csv(Path(f.name), sep=',') + fi2 = fi.clone() + assert repr(fi) == repr(fi2) + os.remove(f.name) + + def test_data_header_no_dataframe(self): li = [1.0, 1.0, 2.0] df = pandas.DataFrame(li)