diff --git a/launch/cli/bin.py b/launch/cli/bin.py index 797db1b5..88101c96 100644 --- a/launch/cli/bin.py +++ b/launch/cli/bin.py @@ -1,42 +1,11 @@ -from dataclasses import dataclass -from typing import Optional - import click from launch.cli.bundles import bundles +from launch.cli.config import ContextObject, config, set_config from launch.cli.endpoints import endpoints -@dataclass -class ContextObject: - self_hosted: bool - gateway_endpoint: Optional[str] = None - api_key: Optional[str] = None - - @click.group("cli") -@click.option( - "-s", - "--self-hosted", - is_flag=True, - help="Use this flag if Scale Launch is self hosted", -) -@click.option( - "-e", - "--gateway-endpoint", - envvar="LAUNCH_GATEWAY_ENDPOINT", - default=None, - type=str, - help="Redefine Scale Launch gateway endpoint. Mandatory parameter when using self-hosted Scale Launch.", -) -@click.option( - "-a", - "--api-key", - envvar="LAUNCH_API_KEY", - required=True, - type=str, - help="Scale Launch API key", -) @click.pass_context def entry_point(ctx, **kwargs): """ @@ -52,10 +21,13 @@ def entry_point(ctx, **kwargs): ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝ """ - ctx.obj = ContextObject(**kwargs) + ctx.obj = ContextObject().load() + if ctx.obj.api_key is None: + ctx.invoke(set_config) entry_point.add_command(bundles) # type: ignore +entry_point.add_command(config) # type: ignore entry_point.add_command(endpoints) # type: ignore if __name__ == "__main__": diff --git a/launch/cli/bundles.py b/launch/cli/bundles.py index c457f368..cdfa8d96 100644 --- a/launch/cli/bundles.py +++ b/launch/cli/bundles.py @@ -33,7 +33,7 @@ def list_bundles(ctx: click.Context): for model_bundle in client.list_model_bundles(): table.add_row( - model_bundle.bundle_id, + model_bundle.id, model_bundle.name, model_bundle.location, model_bundle.packaging_type, @@ -52,12 +52,13 @@ def get_bundle(ctx: click.Context, bundle_name: str): model_bundle = client.get_model_bundle(bundle_name) console = Console() - console.print(f"bundle_id: {model_bundle.bundle_id}") + console.print(f"bundle_id: {model_bundle.id}") console.print(f"bundle_name: {model_bundle.name}") console.print(f"location: {model_bundle.location}") console.print(f"packaging_type: {model_bundle.packaging_type}") console.print(f"env_params: {model_bundle.env_params}") console.print(f"requirements: {model_bundle.requirements}") + console.print(f"app_config: {model_bundle.app_config}") console.print("metadata:") for meta_name, meta_value in model_bundle.metadata.items(): @@ -65,18 +66,3 @@ def get_bundle(ctx: click.Context, bundle_name: str): console.print(f"{meta_name}:", style="yellow") syntax = Syntax(meta_value, "python") console.print(syntax) - - -@bundles.command("delete") -@click.argument("bundle_name") -@click.pass_context -def delete_bundle(ctx: click.Context, bundle_name: str): - """ - Deletes a model bundle. - """ - client = init_client(ctx) - - console = Console() - model_bundle = client.get_model_bundle(bundle_name) - res = client.delete_model_bundle(model_bundle) - console.print(res) diff --git a/launch/cli/config.py b/launch/cli/config.py new file mode 100644 index 00000000..91495fbe --- /dev/null +++ b/launch/cli/config.py @@ -0,0 +1,85 @@ +import json +import os +from dataclasses import asdict, dataclass +from typing import Optional + +import click +import questionary as q +from rich.console import Console +from rich.table import Table + + +@dataclass +class ContextObject: + self_hosted: Optional[bool] = False + gateway_endpoint: Optional[str] = None + api_key: Optional[str] = None + + @staticmethod + def config_path(): + config_dir = click.get_app_dir("launch") + if not os.path.exists(config_dir): + os.makedirs(config_dir) + return os.path.join(config_dir, "config.json") + + def load(self): + try: + with open(self.config_path(), "r", encoding="utf-8") as f: + new_items = json.load(f) + for key, value in new_items.items(): + if hasattr(self, key): + setattr(self, key, value) + except FileNotFoundError: + pass + + return self + + def save(self): + with open(self.config_path(), "w", encoding="utf-8") as f: + json.dump(asdict(self), f, indent=4) + + +@click.group("config") +@click.pass_context +def config(ctx: click.Context): + """ + Config is a wrapper around getting and setting your Scale Launch configuration + """ + + +@config.command("get") +@click.pass_context +def get_config(ctx: click.Context): + table = Table( + "Self-Hosted", + "API Key", + "Gateway Endpoint", + ) + + table.add_row( + str(ctx.obj.self_hosted), ctx.obj.api_key, ctx.obj.gateway_endpoint + ) + console = Console() + console.print(table) + + +@config.command("set") +@click.pass_context +def set_config(ctx: click.Context): + ctx.obj.api_key = q.text( + message="Your Scale API Key?", + default=ctx.obj.api_key or "", + validate=lambda x: isinstance(x, str) + and len(x) > 16, # Arbitrary length right now + ).ask() + ctx.obj.self_hosted = q.confirm( + message="Is your installation of Launch self-hosted?", + default=ctx.obj.self_hosted, + ).ask() + if ctx.obj.self_hosted: + ctx.obj.gateway_endpoint = q.text( + message="Your Gateway Endpoint?", + default=ctx.obj.gateway_endpoint or "", + ).ask() + + ctx.obj.save() diff --git a/launch/cli/endpoints.py b/launch/cli/endpoints.py index 64487225..cca03359 100644 --- a/launch/cli/endpoints.py +++ b/launch/cli/endpoints.py @@ -20,6 +20,7 @@ def list_endpoints(ctx: click.Context): client = init_client(ctx) table = Table( + "Endpoint ID", "Endpoint name", "Bundle name", "Status", @@ -35,27 +36,28 @@ def list_endpoints(ctx: click.Context): for servable_endpoint in client.list_model_endpoints(): table.add_row( + servable_endpoint.model_endpoint.id, servable_endpoint.model_endpoint.name, servable_endpoint.model_endpoint.bundle_name, servable_endpoint.model_endpoint.status, servable_endpoint.model_endpoint.endpoint_type, str( - (servable_endpoint.model_endpoint.worker_settings or {}).get( + (servable_endpoint.model_endpoint.deployment_state or {}).get( "min_workers", "" ) ), str( - (servable_endpoint.model_endpoint.worker_settings or {}).get( + (servable_endpoint.model_endpoint.deployment_state or {}).get( "max_workers", "" ) ), str( - (servable_endpoint.model_endpoint.worker_settings or {}).get( + (servable_endpoint.model_endpoint.deployment_state or {}).get( "available_workers", "" ) ), str( - (servable_endpoint.model_endpoint.worker_settings or {}).get( + (servable_endpoint.model_endpoint.deployment_state or {}).get( "unavailable_workers", "" ) ), @@ -88,3 +90,25 @@ def read_endpoint_creation_logs(ctx: click.Context, endpoint_name: str): res = client.read_endpoint_creation_logs(endpoint_name) # rich fails to render the text because it's already formatted print(res) + + +@endpoints.command("get") +@click.argument("endpoint_name") +@click.pass_context +def get_endpoint(ctx: click.Context, endpoint_name: str): + """Print bundle info""" + client = init_client(ctx) + + model_endpoint = client.get_model_endpoint(endpoint_name).model_endpoint + + console = Console() + console.print(f"endpoint_id: {model_endpoint.id}") + console.print(f"endpoint_name: {model_endpoint.name}") + console.print(f"bundle_name: {model_endpoint.bundle_name}") + console.print(f"status: {model_endpoint.status}") + console.print(f"resource_state: {model_endpoint.resource_state}") + console.print(f"deployment_state: {model_endpoint.deployment_state}") + console.print(f"metadata: {model_endpoint.metadata}") + console.print(f"endpoint_type: {model_endpoint.endpoint_type}") + console.print(f"configs: {model_endpoint.configs}") + console.print(f"destination: {model_endpoint.destination}") diff --git a/poetry.lock b/poetry.lock index cffe4ae1..53ee8664 100644 --- a/poetry.lock +++ b/poetry.lock @@ -176,6 +176,22 @@ files = [ colorama = {version = "*", markers = "platform_system == \"Windows\""} importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} +[[package]] +name = "click-config-file" +version = "0.6.0" +description = "Configuration file support for click applications." +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "click_config_file-0.6.0-py2.py3-none-any.whl", hash = "sha256:3c5802dec437ed596f181efc988f62b1069cd48a912e280cd840ee70580f39d7"}, + {file = "click_config_file-0.6.0.tar.gz", hash = "sha256:ded6ec1a73c41280727ec9c06031e929cdd8a5946bf0f99c0c3db3a71793d515"}, +] + +[package.dependencies] +click = ">=6.7" +configobj = ">=5.0.6" + [[package]] name = "cloudpickle" version = "2.2.0" @@ -215,6 +231,21 @@ files = [ [package.extras] test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] +[[package]] +name = "configobj" +version = "5.0.8" +description = "Config file reading, writing and validation." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +files = [ + {file = "configobj-5.0.8-py2.py3-none-any.whl", hash = "sha256:a7a8c6ab7daade85c3f329931a807c8aee750a2494363934f8ea84d8a54c87ea"}, + {file = "configobj-5.0.8.tar.gz", hash = "sha256:6f704434a07dc4f4dc7c9a745172c1cad449feb548febd9f7fe362629c627a97"}, +] + +[package.dependencies] +six = "*" + [[package]] name = "coverage" version = "6.5.0" @@ -866,6 +897,21 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "prompt-toolkit" +version = "3.0.36" +description = "Library for building powerful interactive command lines in Python" +category = "main" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "prompt_toolkit-3.0.36-py3-none-any.whl", hash = "sha256:aa64ad242a462c5ff0363a7b9cfe696c20d55d9fc60c11fd8e632d064804d305"}, + {file = "prompt_toolkit-3.0.36.tar.gz", hash = "sha256:3e163f254bef5a03b146397d7c1963bd3e2812f0964bb9a24e6ec761fd28db63"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "pycodestyle" version = "2.8.0" @@ -1085,6 +1131,24 @@ files = [ {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, ] +[[package]] +name = "questionary" +version = "1.10.0" +description = "Python library to build pretty command line user prompts ⭐️" +category = "main" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "questionary-1.10.0-py3-none-any.whl", hash = "sha256:fecfcc8cca110fda9d561cb83f1e97ecbb93c613ff857f655818839dac74ce90"}, + {file = "questionary-1.10.0.tar.gz", hash = "sha256:600d3aefecce26d48d97eee936fdb66e4bc27f934c3ab6dd1e292c4f43946d90"}, +] + +[package.dependencies] +prompt_toolkit = ">=2.0,<4.0" + +[package.extras] +docs = ["Sphinx (>=3.3,<4.0)", "sphinx-autobuild (>=2020.9.1,<2021.0.0)", "sphinx-autodoc-typehints (>=1.11.1,<2.0.0)", "sphinx-copybutton (>=0.3.1,<0.4.0)", "sphinx-rtd-theme (>=0.5.0,<0.6.0)"] + [[package]] name = "requests" version = "2.28.1" @@ -1619,6 +1683,18 @@ platformdirs = ">=2,<3" docs = ["proselint (>=0.10.2)", "sphinx (>=3)", "sphinx-argparse (>=0.2.5)", "sphinx-rtd-theme (>=0.4.3)", "towncrier (>=21.3)"] testing = ["coverage (>=4)", "coverage-enable-subprocess (>=1)", "flaky (>=3)", "packaging (>=20.0)", "pytest (>=4)", "pytest-env (>=0.6.2)", "pytest-freezegun (>=0.4.1)", "pytest-mock (>=2)", "pytest-randomly (>=1)", "pytest-timeout (>=1)"] +[[package]] +name = "wcwidth" +version = "0.2.6" +description = "Measures the displayed width of unicode strings in a terminal" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.6-py2.py3-none-any.whl", hash = "sha256:795b138f6875577cd91bba52baf9e445cd5118fd32723b460e30a0af30ea230e"}, + {file = "wcwidth-0.2.6.tar.gz", hash = "sha256:a5220780a404dbe3353789870978e472cfe477761f06ee55077256e509b156d0"}, +] + [[package]] name = "wrapt" version = "1.14.1" @@ -1712,4 +1788,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "5321d7b785b15301a4dbabeb856c2887f8bc0e56d455f9956fb8f23781499ac7" +content-hash = "4089e9737dfe07e006a1fc328951944a8c36d425bddd70458c6fe174215c2c5f" diff --git a/pyproject.toml b/pyproject.toml index c0e2b24c..12a1ac17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,8 @@ click = "^8.0.0" frozendict = "^2.3.4" pydantic = "^1.10.4" types-frozendict = "^2.0.9" +click-config-file = "^0.6.0" +questionary = "^1.10.0" [tool.poetry.dev-dependencies] black = "^22.1.0"