|
80 | 80 | ]
|
81 | 81 |
|
82 | 82 |
|
| 83 | +def _find_parser(parser, cmd_cls): |
| 84 | + defaults = parser._defaults # pylint: disable=protected-access |
| 85 | + if not cmd_cls or cmd_cls == defaults.get("func"): |
| 86 | + parser.print_help() |
| 87 | + raise DvcParserError() |
| 88 | + |
| 89 | + actions = parser._actions # pylint: disable=protected-access |
| 90 | + for action in actions: |
| 91 | + if not isinstance(action.choices, dict): |
| 92 | + # NOTE: we are only interested in subparsers |
| 93 | + continue |
| 94 | + for subparser in action.choices.values(): |
| 95 | + _find_parser(subparser, cmd_cls) |
| 96 | + |
| 97 | + |
83 | 98 | class DvcParser(argparse.ArgumentParser):
|
84 | 99 | """Custom parser class for dvc CLI."""
|
85 | 100 |
|
86 |
| - def error(self, message, command=None): # pylint: disable=arguments-differ |
87 |
| - """Custom error method. |
88 |
| - Args: |
89 |
| - message (str): error message. |
90 |
| - command (str): subcommand name for help message |
91 |
| - Raises: |
92 |
| - dvc.exceptions.DvcParser: dvc parser exception. |
93 |
| -
|
94 |
| - """ |
| 101 | + def error(self, message, cmd_cls=None): # pylint: disable=arguments-differ |
95 | 102 | logger.error(message)
|
96 |
| - if command is not None: |
97 |
| - for action in self._actions: |
98 |
| - if action.dest == "cmd" and command in action.choices: |
99 |
| - subparser = action.choices[command] |
100 |
| - subparser.print_help() |
101 |
| - raise DvcParserError() |
102 |
| - self.print_help() |
103 |
| - raise DvcParserError() |
| 103 | + _find_parser(self, cmd_cls) |
104 | 104 |
|
105 |
| - # override this to send subcommand name to error method |
106 | 105 | def parse_args(self, args=None, namespace=None):
|
| 106 | + # NOTE: overriding to provide a more granular help message. |
| 107 | + # E.g. `dvc plots diff --bad-flag` would result in a `dvc plots diff` |
| 108 | + # help message instead of generic `dvc` usage. |
107 | 109 | args, argv = self.parse_known_args(args, namespace)
|
108 | 110 | if argv:
|
109 | 111 | msg = "unrecognized arguments: %s"
|
110 |
| - self.error(msg % " ".join(argv), args.cmd) |
| 112 | + self.error(msg % " ".join(argv), getattr(args, "func", None)) |
111 | 113 | return args
|
112 | 114 |
|
113 | 115 |
|
|
0 commit comments