Skip to content

Commit 0404a23

Browse files
authored
Merge pull request #2551 from MaxRis/gdrive-support
GDrive remote support
2 parents 94f803e + e722789 commit 0404a23

File tree

10 files changed

+441
-3
lines changed

10 files changed

+441
-3
lines changed

dvc/config.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ class Config(object): # pylint: disable=too-many-instance-attributes
110110
CONFIG = "config"
111111
CONFIG_LOCAL = "config.local"
112112

113+
CREDENTIALPATH = "credentialpath"
114+
113115
LEVEL_LOCAL = 0
114116
LEVEL_REPO = 1
115117
LEVEL_GLOBAL = 2
@@ -162,7 +164,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes
162164
}
163165

164166
# aws specific options
165-
SECTION_AWS_CREDENTIALPATH = "credentialpath"
167+
SECTION_AWS_CREDENTIALPATH = CREDENTIALPATH
166168
SECTION_AWS_ENDPOINT_URL = "endpointurl"
167169
SECTION_AWS_LIST_OBJECTS = "listobjects"
168170
SECTION_AWS_REGION = "region"
@@ -172,7 +174,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes
172174
SECTION_AWS_ACL = "acl"
173175

174176
# gcp specific options
175-
SECTION_GCP_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH
177+
SECTION_GCP_CREDENTIALPATH = CREDENTIALPATH
176178
SECTION_GCP_PROJECTNAME = "projectname"
177179

178180
# azure specific option
@@ -183,6 +185,11 @@ class Config(object): # pylint: disable=too-many-instance-attributes
183185
SECTION_OSS_ACCESS_KEY_SECRET = "oss_key_secret"
184186
SECTION_OSS_ENDPOINT = "oss_endpoint"
185187

188+
# GDrive options
189+
SECTION_GDRIVE_CLIENT_ID = "gdrive_client_id"
190+
SECTION_GDRIVE_CLIENT_SECRET = "gdrive_client_secret"
191+
SECTION_GDRIVE_USER_CREDENTIALS_FILE = "gdrive_user_credentials_file"
192+
186193
SECTION_REMOTE_REGEX = r'^\s*remote\s*"(?P<name>.*)"\s*$'
187194
SECTION_REMOTE_FMT = 'remote "{}"'
188195
SECTION_REMOTE_URL = "url"
@@ -218,6 +225,9 @@ class Config(object): # pylint: disable=too-many-instance-attributes
218225
SECTION_OSS_ACCESS_KEY_ID: str,
219226
SECTION_OSS_ACCESS_KEY_SECRET: str,
220227
SECTION_OSS_ENDPOINT: str,
228+
SECTION_GDRIVE_CLIENT_ID: str,
229+
SECTION_GDRIVE_CLIENT_SECRET: str,
230+
SECTION_GDRIVE_USER_CREDENTIALS_FILE: str,
221231
PRIVATE_CWD: str,
222232
Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): Bool,
223233
}

dvc/remote/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .config import RemoteConfig
44
from dvc.remote.azure import RemoteAZURE
5+
from dvc.remote.gdrive import RemoteGDrive
56
from dvc.remote.gs import RemoteGS
67
from dvc.remote.hdfs import RemoteHDFS
78
from dvc.remote.http import RemoteHTTP
@@ -14,6 +15,7 @@
1415

1516
REMOTES = [
1617
RemoteAZURE,
18+
RemoteGDrive,
1719
RemoteGS,
1820
RemoteHDFS,
1921
RemoteHTTP,

dvc/remote/gdrive/__init__.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
from __future__ import unicode_literals
2+
3+
import os
4+
import posixpath
5+
import logging
6+
import threading
7+
8+
from funcy import retry, compose, decorator, wrap_with
9+
from funcy.py3 import cat
10+
11+
from dvc.remote.gdrive.utils import TrackFileReadProgress, FOLDER_MIME_TYPE
12+
from dvc.scheme import Schemes
13+
from dvc.path_info import CloudURLInfo
14+
from dvc.remote.base import RemoteBASE
15+
from dvc.config import Config
16+
from dvc.exceptions import DvcException
17+
from dvc.utils import tmp_fname
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class GDriveRetriableError(DvcException):
23+
def __init__(self, msg):
24+
super(GDriveRetriableError, self).__init__(msg)
25+
26+
27+
@decorator
28+
def _wrap_pydrive_retriable(call):
29+
from apiclient import errors
30+
from pydrive.files import ApiRequestError
31+
32+
try:
33+
result = call()
34+
except (ApiRequestError, errors.HttpError) as exception:
35+
retry_codes = ["403", "500", "502", "503", "504"]
36+
if any(
37+
"HttpError {}".format(code) in str(exception)
38+
for code in retry_codes
39+
):
40+
raise GDriveRetriableError(msg="Google API request failed")
41+
raise
42+
return result
43+
44+
45+
gdrive_retry = compose(
46+
# 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s
47+
retry(
48+
8, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 10)
49+
),
50+
_wrap_pydrive_retriable,
51+
)
52+
53+
54+
class RemoteGDrive(RemoteBASE):
55+
scheme = Schemes.GDRIVE
56+
path_cls = CloudURLInfo
57+
REQUIRES = {"pydrive": "pydrive"}
58+
GDRIVE_USER_CREDENTIALS_DATA = "GDRIVE_USER_CREDENTIALS_DATA"
59+
DEFAULT_USER_CREDENTIALS_FILE = ".dvc/tmp/gdrive-user-credentials.json"
60+
61+
def __init__(self, repo, config):
62+
super(RemoteGDrive, self).__init__(repo, config)
63+
self.no_traverse = False
64+
self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL])
65+
self.config = config
66+
self.init_drive()
67+
68+
def init_drive(self):
69+
self.client_id = self.config.get(Config.SECTION_GDRIVE_CLIENT_ID, None)
70+
self.client_secret = self.config.get(
71+
Config.SECTION_GDRIVE_CLIENT_SECRET, None
72+
)
73+
if not self.client_id or not self.client_secret:
74+
raise DvcException(
75+
"Please specify Google Drive's client id and "
76+
"secret in DVC's config. Learn more at "
77+
"https://man.dvc.org/remote/add."
78+
)
79+
self.gdrive_user_credentials_path = (
80+
tmp_fname(".dvc/tmp/")
81+
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
82+
else self.config.get(
83+
Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE,
84+
self.DEFAULT_USER_CREDENTIALS_FILE,
85+
)
86+
)
87+
88+
@gdrive_retry
89+
def gdrive_upload_file(
90+
self, args, no_progress_bar=True, from_file="", progress_name=""
91+
):
92+
item = self.drive.CreateFile(
93+
{"title": args["title"], "parents": [{"id": args["parent_id"]}]}
94+
)
95+
self.upload_file(item, no_progress_bar, from_file, progress_name)
96+
return item
97+
98+
def upload_file(self, item, no_progress_bar, from_file, progress_name):
99+
with open(from_file, "rb") as opened_file:
100+
if not no_progress_bar:
101+
opened_file = TrackFileReadProgress(progress_name, opened_file)
102+
# PyDrive doesn't like content property setting for empty files
103+
# https://github.com/gsuitedevs/PyDrive/issues/121
104+
if os.stat(from_file).st_size:
105+
item.content = opened_file
106+
item.Upload()
107+
108+
@gdrive_retry
109+
def gdrive_download_file(
110+
self, file_id, to_file, progress_name, no_progress_bar
111+
):
112+
from dvc.progress import Tqdm
113+
114+
gdrive_file = self.drive.CreateFile({"id": file_id})
115+
with Tqdm(
116+
desc=progress_name,
117+
total=int(gdrive_file["fileSize"]),
118+
disable=no_progress_bar,
119+
):
120+
gdrive_file.GetContentFile(to_file)
121+
122+
def gdrive_list_item(self, query):
123+
file_list = self.drive.ListFile({"q": query, "maxResults": 1000})
124+
125+
# Isolate and decorate fetching of remote drive items in pages
126+
get_list = gdrive_retry(lambda: next(file_list, None))
127+
128+
# Fetch pages until None is received, lazily flatten the thing
129+
return cat(iter(get_list, None))
130+
131+
def cache_root_dirs(self):
132+
cached_dirs = {}
133+
cached_ids = {}
134+
for dir1 in self.gdrive_list_item(
135+
"'{}' in parents and trashed=false".format(self.root_id)
136+
):
137+
remote_path = posixpath.join(self.path_info.path, dir1["title"])
138+
cached_dirs.setdefault(remote_path, []).append(dir1["id"])
139+
cached_ids[dir1["id"]] = dir1["title"]
140+
return cached_dirs, cached_ids
141+
142+
@property
143+
def cached_dirs(self):
144+
if not hasattr(self, "_cached_dirs"):
145+
self.drive
146+
return self._cached_dirs
147+
148+
@property
149+
def cached_ids(self):
150+
if not hasattr(self, "_cached_ids"):
151+
self.drive
152+
return self._cached_ids
153+
154+
@property
155+
@wrap_with(threading.RLock())
156+
def drive(self):
157+
if not hasattr(self, "_gdrive"):
158+
from pydrive.auth import GoogleAuth
159+
from pydrive.drive import GoogleDrive
160+
161+
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
162+
with open(
163+
self.gdrive_user_credentials_path, "w"
164+
) as credentials_file:
165+
credentials_file.write(
166+
os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
167+
)
168+
169+
GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
170+
GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
171+
"client_id": self.client_id,
172+
"client_secret": self.client_secret,
173+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
174+
"token_uri": "https://oauth2.googleapis.com/token",
175+
"revoke_uri": "https://oauth2.googleapis.com/revoke",
176+
"redirect_uri": "",
177+
}
178+
GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True
179+
GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
180+
GoogleAuth.DEFAULT_SETTINGS[
181+
"save_credentials_file"
182+
] = self.gdrive_user_credentials_path
183+
GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
184+
GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
185+
"https://www.googleapis.com/auth/drive",
186+
"https://www.googleapis.com/auth/drive.appdata",
187+
]
188+
189+
# Pass non existent settings path to force DEFAULT_SETTINGS loading
190+
gauth = GoogleAuth(settings_file="")
191+
gauth.CommandLineAuth()
192+
193+
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
194+
os.remove(self.gdrive_user_credentials_path)
195+
196+
self._gdrive = GoogleDrive(gauth)
197+
198+
self.root_id = self.get_remote_id(self.path_info, create=True)
199+
self._cached_dirs, self._cached_ids = self.cache_root_dirs()
200+
201+
return self._gdrive
202+
203+
@gdrive_retry
204+
def create_remote_dir(self, parent_id, title):
205+
item = self.drive.CreateFile(
206+
{
207+
"title": title,
208+
"parents": [{"id": parent_id}],
209+
"mimeType": FOLDER_MIME_TYPE,
210+
}
211+
)
212+
item.Upload()
213+
return item
214+
215+
@gdrive_retry
216+
def get_remote_item(self, name, parents_ids):
217+
if not parents_ids:
218+
return None
219+
query = " or ".join(
220+
"'{}' in parents".format(parent_id) for parent_id in parents_ids
221+
)
222+
223+
query += " and trashed=false and title='{}'".format(name)
224+
225+
# Limit found remote items count to 1 in response
226+
item_list = self.drive.ListFile(
227+
{"q": query, "maxResults": 1}
228+
).GetList()
229+
return next(iter(item_list), None)
230+
231+
def resolve_remote_item_from_path(self, path_parts, create):
232+
parents_ids = ["root"]
233+
current_path = ""
234+
for path_part in path_parts:
235+
current_path = posixpath.join(current_path, path_part)
236+
remote_ids = self.get_remote_id_from_cache(current_path)
237+
if remote_ids:
238+
parents_ids = remote_ids
239+
continue
240+
item = self.get_remote_item(path_part, parents_ids)
241+
if not item and create:
242+
item = self.create_remote_dir(parents_ids[0], path_part)
243+
elif not item:
244+
return None
245+
parents_ids = [item["id"]]
246+
return item
247+
248+
def get_remote_id_from_cache(self, remote_path):
249+
if hasattr(self, "_cached_dirs"):
250+
return self.cached_dirs.get(remote_path, [])
251+
return []
252+
253+
def get_remote_id(self, path_info, create=False):
254+
remote_ids = self.get_remote_id_from_cache(path_info.path)
255+
256+
if remote_ids:
257+
return remote_ids[0]
258+
259+
file1 = self.resolve_remote_item_from_path(
260+
path_info.path.split("/"), create
261+
)
262+
return file1["id"] if file1 else ""
263+
264+
def exists(self, path_info):
265+
return self.get_remote_id(path_info) != ""
266+
267+
def _upload(self, from_file, to_info, name, no_progress_bar):
268+
dirname = to_info.parent
269+
if dirname:
270+
parent_id = self.get_remote_id(dirname, True)
271+
else:
272+
parent_id = to_info.bucket
273+
274+
self.gdrive_upload_file(
275+
{"title": to_info.name, "parent_id": parent_id},
276+
no_progress_bar,
277+
from_file,
278+
name,
279+
)
280+
281+
def _download(self, from_info, to_file, name, no_progress_bar):
282+
file_id = self.get_remote_id(from_info)
283+
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)
284+
285+
def all(self):
286+
if not self.cached_ids:
287+
return
288+
289+
query = " or ".join(
290+
"'{}' in parents".format(dir_id) for dir_id in self.cached_ids
291+
)
292+
293+
query += " and trashed=false"
294+
for file1 in self.gdrive_list_item(query):
295+
parent_id = file1["parents"][0]["id"]
296+
path = posixpath.join(self.cached_ids[parent_id], file1["title"])
297+
try:
298+
yield self.path_to_checksum(path)
299+
except ValueError:
300+
# We ignore all the non-cache looking files
301+
logger.debug('Ignoring path as "non-cache looking"')

dvc/remote/gdrive/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import os
2+
3+
from dvc.progress import Tqdm
4+
5+
6+
FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"
7+
8+
9+
class TrackFileReadProgress(object):
10+
def __init__(self, progress_name, fobj):
11+
self.progress_name = progress_name
12+
self.fobj = fobj
13+
file_size = os.fstat(fobj.fileno()).st_size
14+
self.tqdm = Tqdm(desc=self.progress_name, total=file_size)
15+
16+
def read(self, size):
17+
self.tqdm.update(size)
18+
return self.fobj.read(size)
19+
20+
def close(self):
21+
self.fobj.close()
22+
self.tqdm.close()
23+
24+
def __getattr__(self, attr):
25+
return getattr(self.fobj, attr)

dvc/scheme.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ class Schemes:
99
HTTP = "http"
1010
HTTPS = "https"
1111
GS = "gs"
12+
GDRIVE = "gdrive"
1213
LOCAL = "local"
1314
OSS = "oss"

0 commit comments

Comments
 (0)