|
| 1 | +from __future__ import unicode_literals |
| 2 | + |
| 3 | +import os |
| 4 | +import posixpath |
| 5 | +import logging |
| 6 | +import threading |
| 7 | + |
| 8 | +from funcy import retry, compose, decorator, wrap_with |
| 9 | +from funcy.py3 import cat |
| 10 | + |
| 11 | +from dvc.remote.gdrive.utils import TrackFileReadProgress, FOLDER_MIME_TYPE |
| 12 | +from dvc.scheme import Schemes |
| 13 | +from dvc.path_info import CloudURLInfo |
| 14 | +from dvc.remote.base import RemoteBASE |
| 15 | +from dvc.config import Config |
| 16 | +from dvc.exceptions import DvcException |
| 17 | +from dvc.utils import tmp_fname |
| 18 | + |
| 19 | +logger = logging.getLogger(__name__) |
| 20 | + |
| 21 | + |
class GDriveRetriableError(DvcException):
    """Error for Google Drive API failures that are worth retrying.

    Raised by ``_wrap_pydrive_retriable`` so that the ``retry`` decorator
    in ``gdrive_retry`` can catch exactly this type and re-attempt the call.
    """

    def __init__(self, msg):
        # Delegate message handling to the DvcException base class
        # (py2-compatible super call, matching the rest of the file).
        super(GDriveRetriableError, self).__init__(msg)
| 25 | + |
| 26 | + |
@decorator
def _wrap_pydrive_retriable(call):
    """Invoke *call* and translate retriable Google API HTTP failures.

    If the wrapped call raises a PyDrive/apiclient HTTP error whose text
    mentions one of the retriable status codes, re-raise it as
    GDriveRetriableError so the retry decorator can act on it; any other
    error propagates unchanged.
    """
    from apiclient import errors
    from pydrive.files import ApiRequestError

    # Status codes considered transient (rate limiting / server errors).
    retriable_codes = ["403", "500", "502", "503", "504"]

    try:
        return call()
    except (ApiRequestError, errors.HttpError) as exc:
        text = str(exc)
        if any(
            "HttpError {}".format(code) in text for code in retriable_codes
        ):
            raise GDriveRetriableError(msg="Google API request failed")
        raise
| 43 | + |
| 44 | + |
# Decorator combining error translation with retrying: the inner
# _wrap_pydrive_retriable converts retriable HTTP failures into
# GDriveRetriableError, and the outer retry() re-runs the call on
# exactly that exception type with exponential backoff.
gdrive_retry = compose(
    # 8 tries, start at 0.5s, multiply by golden ratio, cap at 10s
    retry(
        8, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 10)
    ),
    _wrap_pydrive_retriable,
)
| 52 | + |
| 53 | + |
class RemoteGDrive(RemoteBASE):
    """DVC remote backed by Google Drive, accessed through PyDrive.

    Authentication uses an OAuth client id/secret taken from the DVC
    remote config.  User credentials are persisted to a local file, or
    supplied directly through the GDRIVE_USER_CREDENTIALS_DATA
    environment variable (in which case they are written to a temporary
    file only for the duration of the auth handshake).

    NOTE: the ``drive`` property performs lazy initialization with side
    effects (sets ``root_id``, ``_cached_dirs`` and ``_cached_ids``);
    several other members depend on it having run first.
    """

    scheme = Schemes.GDRIVE
    path_cls = CloudURLInfo
    # pip dependency required for this remote to be usable.
    REQUIRES = {"pydrive": "pydrive"}
    # Env var that may carry the user-credentials JSON payload directly.
    GDRIVE_USER_CREDENTIALS_DATA = "GDRIVE_USER_CREDENTIALS_DATA"
    # Fallback on-disk location for saved user credentials.
    DEFAULT_USER_CREDENTIALS_FILE = ".dvc/tmp/gdrive-user-credentials.json"

    def __init__(self, repo, config):
        """Read the remote URL from config and validate GDrive settings."""
        super(RemoteGDrive, self).__init__(repo, config)
        self.no_traverse = False
        self.path_info = self.path_cls(config[Config.SECTION_REMOTE_URL])
        self.config = config
        self.init_drive()

    def init_drive(self):
        """Validate OAuth config and decide where user credentials live.

        Raises:
            DvcException: if client id or secret is missing from config.
        """
        self.client_id = self.config.get(Config.SECTION_GDRIVE_CLIENT_ID, None)
        self.client_secret = self.config.get(
            Config.SECTION_GDRIVE_CLIENT_SECRET, None
        )
        if not self.client_id or not self.client_secret:
            raise DvcException(
                "Please specify Google Drive's client id and "
                "secret in DVC's config. Learn more at "
                "https://man.dvc.org/remote/add."
            )
        # When the env var carries the credentials payload, use a throwaway
        # temp file (written and deleted inside the `drive` property);
        # otherwise use the configured (or default) credentials file path.
        self.gdrive_user_credentials_path = (
            tmp_fname(".dvc/tmp/")
            if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
            else self.config.get(
                Config.SECTION_GDRIVE_USER_CREDENTIALS_FILE,
                self.DEFAULT_USER_CREDENTIALS_FILE,
            )
        )

    @gdrive_retry
    def gdrive_upload_file(
        self, args, no_progress_bar=True, from_file="", progress_name=""
    ):
        """Create a Drive file named ``args["title"]`` under
        ``args["parent_id"]`` and upload ``from_file``'s contents into it.

        Returns the created PyDrive file item.  Retried on transient
        API errors via ``gdrive_retry``.
        """
        item = self.drive.CreateFile(
            {"title": args["title"], "parents": [{"id": args["parent_id"]}]}
        )
        self.upload_file(item, no_progress_bar, from_file, progress_name)
        return item

    def upload_file(self, item, no_progress_bar, from_file, progress_name):
        """Stream the local file ``from_file`` into PyDrive ``item``,
        optionally wrapping the file object to report read progress."""
        with open(from_file, "rb") as opened_file:
            if not no_progress_bar:
                opened_file = TrackFileReadProgress(progress_name, opened_file)
            # PyDrive doesn't like content property setting for empty files
            # https://github.com/gsuitedevs/PyDrive/issues/121
            if os.stat(from_file).st_size:
                item.content = opened_file
            item.Upload()

    @gdrive_retry
    def gdrive_download_file(
        self, file_id, to_file, progress_name, no_progress_bar
    ):
        """Download Drive file ``file_id`` into local path ``to_file``.

        NOTE(review): the Tqdm bar is created with the remote file size
        but is never advanced — GetContentFile downloads in one call, so
        the bar only marks start/end; confirm whether per-chunk progress
        was intended.
        """
        from dvc.progress import Tqdm

        gdrive_file = self.drive.CreateFile({"id": file_id})
        with Tqdm(
            desc=progress_name,
            total=int(gdrive_file["fileSize"]),
            disable=no_progress_bar,
        ):
            gdrive_file.GetContentFile(to_file)

    def gdrive_list_item(self, query):
        """Return a lazy iterable of all Drive items matching ``query``,
        transparently following result pages (1000 items per page)."""
        file_list = self.drive.ListFile({"q": query, "maxResults": 1000})

        # Isolate and decorate fetching of remote drive items in pages
        get_list = gdrive_retry(lambda: next(file_list, None))

        # Fetch pages until None is received, lazily flatten the thing
        return cat(iter(get_list, None))

    def cache_root_dirs(self):
        """List the direct children of ``root_id`` and build two caches.

        Returns:
            tuple: (``{remote_path: [dir_id, ...]}``,
                    ``{dir_id: title}``) for every non-trashed child.
        """
        cached_dirs = {}
        cached_ids = {}
        for dir1 in self.gdrive_list_item(
            "'{}' in parents and trashed=false".format(self.root_id)
        ):
            remote_path = posixpath.join(self.path_info.path, dir1["title"])
            # A title can map to several ids — Drive allows duplicate names.
            cached_dirs.setdefault(remote_path, []).append(dir1["id"])
            cached_ids[dir1["id"]] = dir1["title"]
        return cached_dirs, cached_ids

    @property
    def cached_dirs(self):
        """Mapping of remote paths to Drive ids for the root's children;
        touching ``self.drive`` triggers lazy initialization on first use."""
        if not hasattr(self, "_cached_dirs"):
            self.drive
        return self._cached_dirs

    @property
    def cached_ids(self):
        """Mapping of Drive ids to titles for the root's children;
        touching ``self.drive`` triggers lazy initialization on first use."""
        if not hasattr(self, "_cached_ids"):
            self.drive
        return self._cached_ids

    @property
    @wrap_with(threading.RLock())
    def drive(self):
        """Lazily authenticated GoogleDrive client (thread-safe via RLock).

        First access performs the OAuth flow (CommandLineAuth), then — as
        a side effect — resolves/creates the remote root directory and
        populates the directory caches.  Subsequent accesses return the
        memoized client.
        """
        if not hasattr(self, "_gdrive"):
            from pydrive.auth import GoogleAuth
            from pydrive.drive import GoogleDrive

            # Materialize env-provided credentials into the temp file that
            # init_drive() picked, so PyDrive's file backend can read them.
            if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
                with open(
                    self.gdrive_user_credentials_path, "w"
                ) as credentials_file:
                    credentials_file.write(
                        os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
                    )

            GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
            GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
                "client_id": self.client_id,
                "client_secret": self.client_secret,
                "auth_uri": "https://accounts.google.com/o/oauth2/auth",
                "token_uri": "https://oauth2.googleapis.com/token",
                "revoke_uri": "https://oauth2.googleapis.com/revoke",
                "redirect_uri": "",
            }
            GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True
            GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
            GoogleAuth.DEFAULT_SETTINGS[
                "save_credentials_file"
            ] = self.gdrive_user_credentials_path
            GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
            GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
                "https://www.googleapis.com/auth/drive",
                "https://www.googleapis.com/auth/drive.appdata",
            ]

            # Pass non existent settings path to force DEFAULT_SETTINGS loading
            gauth = GoogleAuth(settings_file="")
            gauth.CommandLineAuth()

            # Clean up the temp credentials file written above.
            if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
                os.remove(self.gdrive_user_credentials_path)

            self._gdrive = GoogleDrive(gauth)

            # Side effects: resolve (or create) the remote root and warm
            # the caches that cached_dirs/cached_ids expose.
            self.root_id = self.get_remote_id(self.path_info, create=True)
            self._cached_dirs, self._cached_ids = self.cache_root_dirs()

        return self._gdrive

    @gdrive_retry
    def create_remote_dir(self, parent_id, title):
        """Create and return a Drive folder ``title`` under ``parent_id``."""
        item = self.drive.CreateFile(
            {
                "title": title,
                "parents": [{"id": parent_id}],
                "mimeType": FOLDER_MIME_TYPE,
            }
        )
        item.Upload()
        return item

    @gdrive_retry
    def get_remote_item(self, name, parents_ids):
        """Find one non-trashed Drive item titled ``name`` under any of
        ``parents_ids``; return it, or None if absent/no parents given.

        NOTE(review): ``name`` is interpolated into the query unescaped —
        a title containing a single quote would break the query syntax;
        confirm whether titles are restricted upstream.
        """
        if not parents_ids:
            return None
        query = " or ".join(
            "'{}' in parents".format(parent_id) for parent_id in parents_ids
        )

        query += " and trashed=false and title='{}'".format(name)

        # Limit found remote items count to 1 in response
        item_list = self.drive.ListFile(
            {"q": query, "maxResults": 1}
        ).GetList()
        return next(iter(item_list), None)

    def resolve_remote_item_from_path(self, path_parts, create):
        """Walk ``path_parts`` from Drive's "root", resolving each segment
        to an item (via cache when possible); optionally create missing
        directories.  Returns the final item, or None if not found and
        ``create`` is false.
        """
        parents_ids = ["root"]
        current_path = ""
        for path_part in path_parts:
            current_path = posixpath.join(current_path, path_part)
            # Prefer the warm cache; fall through to an API lookup on miss.
            remote_ids = self.get_remote_id_from_cache(current_path)
            if remote_ids:
                parents_ids = remote_ids
                continue
            item = self.get_remote_item(path_part, parents_ids)
            if not item and create:
                item = self.create_remote_dir(parents_ids[0], path_part)
            elif not item:
                return None
            parents_ids = [item["id"]]
        return item

    def get_remote_id_from_cache(self, remote_path):
        """Return cached Drive ids for ``remote_path`` ([] on miss); uses
        _cached_dirs presence to avoid triggering lazy drive init here."""
        if hasattr(self, "_cached_dirs"):
            return self.cached_dirs.get(remote_path, [])
        return []

    def get_remote_id(self, path_info, create=False):
        """Return the Drive id for ``path_info``'s path, or "" if absent.

        Checks the cache first, then resolves the path segment by segment,
        optionally creating missing directories when ``create`` is true.
        """
        remote_ids = self.get_remote_id_from_cache(path_info.path)

        if remote_ids:
            return remote_ids[0]

        file1 = self.resolve_remote_item_from_path(
            path_info.path.split("/"), create
        )
        return file1["id"] if file1 else ""

    def exists(self, path_info):
        """Return True if ``path_info`` resolves to an existing Drive id."""
        return self.get_remote_id(path_info) != ""

    def _upload(self, from_file, to_info, name, no_progress_bar):
        """Upload a local file to ``to_info``, creating parent dirs."""
        dirname = to_info.parent
        if dirname:
            parent_id = self.get_remote_id(dirname, True)
        else:
            parent_id = to_info.bucket

        self.gdrive_upload_file(
            {"title": to_info.name, "parent_id": parent_id},
            no_progress_bar,
            from_file,
            name,
        )

    def _download(self, from_info, to_file, name, no_progress_bar):
        """Download the Drive file at ``from_info`` into ``to_file``."""
        file_id = self.get_remote_id(from_info)
        self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

    def all(self):
        """Yield checksums for every cache-looking file in the remote.

        Queries all files under the cached root subdirectories in a single
        Drive request, skipping paths that don't parse as checksums.
        """
        if not self.cached_ids:
            return

        query = " or ".join(
            "'{}' in parents".format(dir_id) for dir_id in self.cached_ids
        )

        query += " and trashed=false"
        for file1 in self.gdrive_list_item(query):
            parent_id = file1["parents"][0]["id"]
            path = posixpath.join(self.cached_ids[parent_id], file1["title"])
            try:
                yield self.path_to_checksum(path)
            except ValueError:
                # We ignore all the non-cache looking files
                logger.debug('Ignoring path as "non-cache looking"')
0 commit comments