diff --git a/dvc/exceptions.py b/dvc/exceptions.py
index 32545c5f75..a886bb965e 100644
--- a/dvc/exceptions.py
+++ b/dvc/exceptions.py
@@ -222,9 +222,12 @@ def __init__(self, etag, cached_etag):
 
 
 class FileMissingError(DvcException):
-    def __init__(self, path):
+    def __init__(self, path, hint=None):
         self.path = path
-        super().__init__(f"Can't find '{path}' neither locally nor on remote")
+        hint = "" if hint is None else f". {hint}"
+        super().__init__(
+            f"Can't find '{path}' neither locally nor on remote{hint}"
+        )
 
 
 class DvcIgnoreInCollectedDirError(DvcException):
diff --git a/dvc/remote/gdrive.py b/dvc/remote/gdrive.py
index 0204a8997e..d983c1eed9 100644
--- a/dvc/remote/gdrive.py
+++ b/dvc/remote/gdrive.py
@@ -1,38 +1,35 @@
+import io
 import logging
 import os
 import posixpath
 import re
 import threading
 from collections import defaultdict
+from contextlib import contextmanager
 from urllib.parse import urlparse
 
 from funcy import cached_property, retry, wrap_prop, wrap_with
 from funcy.py3 import cat
 
-from dvc.exceptions import DvcException
+from dvc.exceptions import DvcException, FileMissingError
 from dvc.path_info import CloudURLInfo
 from dvc.progress import Tqdm
 from dvc.remote.base import BaseRemote, BaseRemoteTree
 from dvc.scheme import Schemes
 from dvc.utils import format_link, tmp_fname
+from dvc.utils.stream import IterStream
 
 logger = logging.getLogger(__name__)
 
 FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"
 
 
-class GDrivePathNotFound(DvcException):
-    def __init__(self, path_info, hint):
-        hint = "" if hint is None else f" {hint}"
-        super().__init__(f"GDrive path '{path_info}' not found.{hint}")
-
-
 class GDriveAuthError(DvcException):
     def __init__(self, cred_location):
         if cred_location:
             message = (
                 "GDrive remote auth failed with credentials in '{}'.\n"
-                "Backup first, remove of fix them, and run DVC again.\n"
+                "Backup first, remove or fix them, and run DVC again.\n"
                 "It should do auth again and refresh the credentials.\n\n"
                 "Details:".format(cred_location)
             )
@@ -389,6 +386,23 @@ def _gdrive_download_file(
         ) as pbar:
             gdrive_file.GetContentFile(to_file, callback=pbar.update_to)
 
+    @contextmanager
+    @_gdrive_retry
+    def open(self, path_info, mode="r", encoding=None):
+        assert mode in {"r", "rt", "rb"}
+
+        item_id = self._get_item_id(path_info)
+        param = {"id": item_id}
+        # it does not create a file on the remote
+        gdrive_file = self._drive.CreateFile(param)
+        fd = gdrive_file.GetContentIOBuffer()
+        stream = IterStream(iter(fd))
+
+        if mode != "rb":
+            stream = io.TextIOWrapper(stream, encoding=encoding)
+
+        yield stream
+
     @_gdrive_retry
     def gdrive_delete_file(self, item_id):
         from pydrive2.files import ApiRequestError
@@ -502,12 +516,12 @@ def _get_item_id(self, path_info, create=False, use_cache=True, hint=None):
             return min(item_ids)
 
         assert not create
-        raise GDrivePathNotFound(path_info, hint)
+        raise FileMissingError(path_info, hint)
 
     def exists(self, path_info):
         try:
             self._get_item_id(path_info)
-        except GDrivePathNotFound:
+        except FileMissingError:
             return False
         else:
             return True
diff --git a/dvc/utils/http.py b/dvc/utils/http.py
index b1fb13cb79..4472b80912 100644
--- a/dvc/utils/http.py
+++ b/dvc/utils/http.py
@@ -1,6 +1,8 @@
 import io
 from contextlib import contextmanager
 
+from dvc.utils.stream import IterStream
+
 
 @contextmanager
 def open_url(url, mode="r", encoding=None):
@@ -61,47 +63,3 @@ def gen(response):
     finally:
         # Ensure connection is closed
         it.close()
-
-
-class IterStream(io.RawIOBase):
-    """Wraps an iterator yielding bytes as a file object"""
-
-    def __init__(self, iterator):
-        self.iterator = iterator
-        self.leftover = None
-
-    def readable(self):
-        return True
-
-    # Python 3 requires only .readinto() method, it still uses other ones
-    # under some circumstances and falls back if those are absent. Since
-    # iterator already constructs byte strings for us, .readinto() is not the
-    # most optimal, so we provide .read1() too.
-
-    def readinto(self, b):
-        try:
-            n = len(b)  # We're supposed to return at most this much
-            chunk = self.leftover or next(self.iterator)
-            output, self.leftover = chunk[:n], chunk[n:]
-
-            n_out = len(output)
-            b[:n_out] = output
-            return n_out
-        except StopIteration:
-            return 0  # indicate EOF
-
-    readinto1 = readinto
-
-    def read1(self, n=-1):
-        try:
-            chunk = self.leftover or next(self.iterator)
-        except StopIteration:
-            return b""
-
-        # Return an arbitrary number or bytes
-        if n <= 0:
-            self.leftover = None
-            return chunk
-
-        output, self.leftover = chunk[:n], chunk[n:]
-        return output
diff --git a/dvc/utils/stream.py b/dvc/utils/stream.py
new file mode 100644
index 0000000000..6109475030
--- /dev/null
+++ b/dvc/utils/stream.py
@@ -0,0 +1,45 @@
+import io
+
+
+class IterStream(io.RawIOBase):
+    """Wraps an iterator yielding bytes as a file object"""
+
+    def __init__(self, iterator):
+        self.iterator = iterator
+        self.leftover = None
+
+    def readable(self):
+        return True
+
+    # Python 3 requires only .readinto() method, it still uses other ones
+    # under some circumstances and falls back if those are absent. Since
+    # iterator already constructs byte strings for us, .readinto() is not the
+    # most optimal, so we provide .read1() too.
+
+    def readinto(self, b):
+        try:
+            n = len(b)  # We're supposed to return at most this much
+            chunk = self.leftover or next(self.iterator)
+            output, self.leftover = chunk[:n], chunk[n:]
+
+            n_out = len(output)
+            b[:n_out] = output
+            return n_out
+        except StopIteration:
+            return 0  # indicate EOF
+
+    readinto1 = readinto
+
+    def read1(self, n=-1):
+        try:
+            chunk = self.leftover or next(self.iterator)
+        except StopIteration:
+            return b""
+
+        # Return an arbitrary number of bytes
+        if n <= 0:
+            self.leftover = None
+            return chunk
+
+        output, self.leftover = chunk[:n], chunk[n:]
+        return output
diff --git a/setup.py b/setup.py
index d60c928e01..87dad025b7 100644
--- a/setup.py
+++ b/setup.py
@@ -84,7 +84,7 @@ def run(self):
 
 # Extra dependencies for remote integrations
 gs = ["google-cloud-storage==1.19.0"]
-gdrive = ["pydrive2>=1.4.13"]
+gdrive = ["pydrive2>=1.4.14"]
 s3 = ["boto3>=1.9.201"]
 azure = ["azure-storage-blob==2.1.0"]
 oss = ["oss2==2.6.1"]
diff --git a/tests/func/test_api.py b/tests/func/test_api.py
index 825ffa57a1..6e69d35831 100644
--- a/tests/func/test_api.py
+++ b/tests/func/test_api.py
@@ -8,9 +8,19 @@
 from dvc.main import main
 from dvc.path_info import URLInfo
 from dvc.utils.fs import remove
-from tests.remotes import GCP, HDFS, OSS, S3, SSH, Azure, Local
-
-remote_params = [S3, GCP, Azure, OSS, SSH, HDFS]
+from tests.remotes import (
+    GCP,
+    HDFS,
+    OSS,
+    S3,
+    SSH,
+    TEST_REMOTE,
+    Azure,
+    GDrive,
+    Local,
+)
+
+remote_params = [S3, GCP, Azure, GDrive, OSS, SSH, HDFS]
 
 all_remote_params = [Local] + remote_params
 
@@ -25,9 +35,48 @@ def run_dvc(*argv):
     assert main(argv) == 0
 
 
+def ensure_dir(dvc, url):
+    if url.startswith("gdrive://"):
+        GDrive.create_dir(dvc, url)
+        run_dvc(
+            "remote",
+            "modify",
+            TEST_REMOTE,
+            "gdrive_service_account_email",
+            "test",
+        )
+        run_dvc(
+            "remote",
+            "modify",
+            TEST_REMOTE,
+            "gdrive_service_account_p12_file_path",
+            "test.p12",
+        )
+        run_dvc(
+            "remote",
+            "modify",
+            TEST_REMOTE,
+            "gdrive_use_service_account",
+            "True",
+        )
+
+
+def ensure_dir_scm(dvc, url):
+    if url.startswith("gdrive://"):
+        GDrive.create_dir(dvc, url)
+        with dvc.config.edit() as conf:
+            conf["remote"][TEST_REMOTE].update(
+                gdrive_service_account_email="test",
+                gdrive_service_account_p12_file_path="test.p12",
+                gdrive_use_service_account=True,
+            )
+        dvc.scm.add(dvc.config.files["repo"])
+        dvc.scm.commit(f"modify '{TEST_REMOTE}' remote")
+
+
 @pytest.mark.parametrize("remote_url", remote_params, indirect=True)
 def test_get_url(tmp_dir, dvc, remote_url):
-    run_dvc("remote", "add", "-d", "upstream", remote_url)
+    run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
     tmp_dir.dvc_gen("foo", "foo")
 
     expected_url = URLInfo(remote_url) / "ac/bd18db4cc2f85cedef654fccc4a4d8"
@@ -58,7 +107,8 @@ def test_get_url_requires_dvc(tmp_dir, scm):
 
 @pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
 def test_open(remote_url, tmp_dir, dvc):
-    run_dvc("remote", "add", "-d", "upstream", remote_url)
+    run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
+    ensure_dir(dvc, remote_url)
     tmp_dir.dvc_gen("foo", "foo-text")
     run_dvc("push")
 
@@ -72,6 +122,7 @@
 def test_open_external(remote_url, erepo_dir, setup_remote):
     setup_remote(erepo_dir.dvc, url=remote_url)
+    ensure_dir_scm(erepo_dir.dvc, remote_url)
 
     with erepo_dir.chdir():
         erepo_dir.dvc_gen("version", "master", commit="add version")
 
@@ -95,7 +146,8 @@
 
 @pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
 def test_open_granular(remote_url, tmp_dir, dvc):
-    run_dvc("remote", "add", "-d", "upstream", remote_url)
+    run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
+    ensure_dir(dvc, remote_url)
     tmp_dir.dvc_gen({"dir": {"foo": "foo-text"}})
     run_dvc("push")
 
@@ -109,7 +161,8 @@
 @pytest.mark.parametrize("remote_url", all_remote_params, indirect=True)
 def test_missing(remote_url, tmp_dir, dvc):
     tmp_dir.dvc_gen("foo", "foo")
-    run_dvc("remote", "add", "-d", "upstream", remote_url)
+    run_dvc("remote", "add", "-d", TEST_REMOTE, remote_url)
+    ensure_dir(dvc, remote_url)
 
     # Remove cache to make foo missing
     remove(dvc.cache.local.cache_dir)
diff --git a/tests/remotes.py b/tests/remotes.py
index a5ce43a30a..cff71a173f 100644
--- a/tests/remotes.py
+++ b/tests/remotes.py
@@ -22,6 +22,7 @@
 
 TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-temp")
 TEST_GCP_REPO_BUCKET = os.environ.get("DVC_TEST_GCP_REPO_BUCKET", "dvc-test")
+TEST_GDRIVE_REPO_BUCKET = "root"
 TEST_OSS_REPO_BUCKET = "dvc-test"
 
 TEST_GCP_CREDS_FILE = os.path.abspath(
@@ -152,10 +153,14 @@
         remote = GDriveRemote(dvc, config)
         remote.tree._gdrive_create_dir("root", remote.path_info.path)
 
+    @staticmethod
+    def get_storagepath():
+        return TEST_GDRIVE_REPO_BUCKET + "/" + str(uuid.uuid4())
+
    @staticmethod
     def get_url():
         # NOTE: `get_url` should always return new random url
-        return "gdrive://root/" + str(uuid.uuid4())
+        return "gdrive://" + GDrive.get_storagepath()
 
 
 class Azure:
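Note (not part of the patch): a minimal sketch of how the relocated `IterStream` is meant to be used. It adapts any iterator of byte chunks into a readable file object, and text mode is layered on top with `io.TextIOWrapper`, mirroring what the new GDrive `open()` does with the buffer returned by pydrive2's `GetContentIOBuffer()`. The `chunks()` generator below is a made-up stand-in for a real chunked download.

    import io

    from dvc.utils.stream import IterStream  # module added by this patch


    def chunks():
        # Hypothetical stand-in for a chunked remote download, e.g. the
        # byte-chunk iterator obtained from GetContentIOBuffer().
        yield b"line one\n"
        yield b"line two\n"


    # Binary mode: the wrapped iterator reads like a regular file object.
    raw = IterStream(chunks())
    assert raw.read() == b"line one\nline two\n"

    # Text mode: wrap the raw stream, as open() does when mode != "rb".
    text = io.TextIOWrapper(IterStream(chunks()), encoding="utf-8")
    assert text.read() == "line one\nline two\n"

Because `IterStream` implements both `readinto()` and `read1()`, `TextIOWrapper` can consume it directly without an intermediate buffered reader, which is what lets the GDrive remote stream file contents instead of downloading to a temporary file first.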