diff --git a/dvc/command/remote.py b/dvc/command/remote.py index 0580537b87..fb562ad8c8 100644 --- a/dvc/command/remote.py +++ b/dvc/command/remote.py @@ -85,7 +85,18 @@ def run(self): if self.args.unset: conf["core"].pop("remote", None) else: - conf["core"]["remote"] = self.args.name + merged_conf = self.config.load_config_to_level( + self.args.level + ) + if ( + self.args.name in conf["remote"] + or self.args.name in merged_conf["remote"] + ): + conf["core"]["remote"] = self.args.name + else: + raise ConfigError( + "default remote must be present in remote list." + ) return 0 diff --git a/dvc/config.py b/dvc/config.py index 1fc63acbb3..e18dbe44db 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -278,7 +278,7 @@ def load(self, validate=True): Raises: ConfigError: thrown if config has an invalid format. """ - conf = self._load_config_to_level() + conf = self.load_config_to_level() if validate: conf = self.validate(conf) @@ -330,7 +330,7 @@ def _map_dirs(conf, func): dirs_schema = {"cache": {"dir": func}, "remote": {str: {"url": func}}} return Schema(dirs_schema, extra=ALLOW_EXTRA)(conf) - def _load_config_to_level(self, level=None): + def load_config_to_level(self, level=None): merged_conf = {} for merge_level in self.LEVELS: if merge_level == level: @@ -349,7 +349,7 @@ def edit(self, level="repo"): conf = self._save_paths(conf, self.files[level]) - merged_conf = self._load_config_to_level(level) + merged_conf = self.load_config_to_level(level) _merge(merged_conf, conf) self.validate(merged_conf) diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index 4369203dd7..ba122846fc 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -121,6 +121,7 @@ def test(self): def test_show_default(dvc, capsys): + assert main(["remote", "add", "foo", "s3://bucket/name"]) == 0 assert main(["remote", "default", "foo"]) == 0 assert main(["remote", "default"]) == 0 out, _ = capsys.readouterr() @@ -270,3 +271,21 @@ def test_remote_modify_validation(dvc): ) config = configobj.ConfigObj(dvc.config.files["repo"]) assert unsupported_config not in config['remote "{}"'.format(remote_name)] + + +def test_remote_modify_default(dvc): + remote_repo = "repo_level" + remote_local = "local_level" + wrong_name = "anything" + assert main(["remote", "add", remote_repo, "s3://bucket/repo"]) == 0 + assert main(["remote", "add", remote_local, "s3://bucket/local"]) == 0 + + assert main(["remote", "default", wrong_name]) == 251 + assert main(["remote", "default", remote_repo]) == 0 + assert main(["remote", "default", "--local", remote_local]) == 0 + + repo_config = configobj.ConfigObj(dvc.config.files["repo"]) + local_config = configobj.ConfigObj(dvc.config.files["local"]) + + assert repo_config["core"]["remote"] == remote_repo + assert local_config["core"]["remote"] == remote_local