Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions launch/cli/bin.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,11 @@
from dataclasses import dataclass
from typing import Optional

import click

from launch.cli.bundles import bundles
from launch.cli.config import ContextObject, config, set_config
from launch.cli.endpoints import endpoints


@dataclass
class ContextObject:
self_hosted: bool
gateway_endpoint: Optional[str] = None
api_key: Optional[str] = None


@click.group("cli")
@click.option(
"-s",
"--self-hosted",
is_flag=True,
help="Use this flag if Scale Launch is self hosted",
)
@click.option(
"-e",
"--gateway-endpoint",
envvar="LAUNCH_GATEWAY_ENDPOINT",
default=None,
type=str,
help="Redefine Scale Launch gateway endpoint. Mandatory parameter when using self-hosted Scale Launch.",
)
@click.option(
"-a",
"--api-key",
envvar="LAUNCH_API_KEY",
required=True,
type=str,
help="Scale Launch API key",
)
@click.pass_context
def entry_point(ctx, **kwargs):
"""
Expand All @@ -52,10 +21,13 @@ def entry_point(ctx, **kwargs):
╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝

"""
ctx.obj = ContextObject(**kwargs)
ctx.obj = ContextObject().load()
if ctx.obj.api_key is None:
ctx.invoke(set_config)


entry_point.add_command(bundles) # type: ignore
entry_point.add_command(config) # type: ignore
entry_point.add_command(endpoints) # type: ignore

if __name__ == "__main__":
Expand Down
20 changes: 3 additions & 17 deletions launch/cli/bundles.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def list_bundles(ctx: click.Context):

for model_bundle in client.list_model_bundles():
table.add_row(
model_bundle.bundle_id,
model_bundle.id,
model_bundle.name,
model_bundle.location,
model_bundle.packaging_type,
Expand All @@ -52,31 +52,17 @@ def get_bundle(ctx: click.Context, bundle_name: str):
model_bundle = client.get_model_bundle(bundle_name)

console = Console()
console.print(f"bundle_id: {model_bundle.bundle_id}")
console.print(f"bundle_id: {model_bundle.id}")
console.print(f"bundle_name: {model_bundle.name}")
console.print(f"location: {model_bundle.location}")
console.print(f"packaging_type: {model_bundle.packaging_type}")
console.print(f"env_params: {model_bundle.env_params}")
console.print(f"requirements: {model_bundle.requirements}")
console.print(f"app_config: {model_bundle.app_config}")

console.print("metadata:")
for meta_name, meta_value in model_bundle.metadata.items():
# TODO print non-code metadata differently
console.print(f"{meta_name}:", style="yellow")
syntax = Syntax(meta_value, "python")
console.print(syntax)


@bundles.command("delete")
@click.argument("bundle_name")
@click.pass_context
def delete_bundle(ctx: click.Context, bundle_name: str):
"""
Deletes a model bundle.
"""
client = init_client(ctx)

console = Console()
model_bundle = client.get_model_bundle(bundle_name)
res = client.delete_model_bundle(model_bundle)
console.print(res)
85 changes: 85 additions & 0 deletions launch/cli/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import json
import os
from dataclasses import asdict, dataclass
from typing import Optional

import click
import questionary as q
from rich.console import Console
from rich.table import Table


@dataclass
class ContextObject:
self_hosted: Optional[bool] = False
gateway_endpoint: Optional[str] = None
api_key: Optional[str] = None

@staticmethod
def config_path():
config_dir = click.get_app_dir("launch")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return os.path.join(config_dir, "config.json")

def load(self):
try:
with open(self.config_path(), "r", encoding="utf-8") as f:
new_items = json.load(f)
for key, value in new_items.items():
if hasattr(self, key):
setattr(self, key, value)
except FileNotFoundError:
pass

return self

def save(self):
with open(self.config_path(), "w", encoding="utf-8") as f:
json.dump(asdict(self), f, indent=4)


@click.group("config")
@click.pass_context
def config(ctx: click.Context):
"""
Config is a wrapper around getting and setting your Scale Launch configuration
"""


@config.command("get")
@click.pass_context
def get_config(ctx: click.Context):
table = Table(
"Self-Hosted",
"API Key",
"Gateway Endpoint",
)

table.add_row(
str(ctx.obj.self_hosted), ctx.obj.api_key, ctx.obj.gateway_endpoint
)
console = Console()
console.print(table)


@config.command("set")
@click.pass_context
def set_config(ctx: click.Context):
ctx.obj.api_key = q.text(
message="Your Scale API Key?",
default=ctx.obj.api_key or "",
validate=lambda x: isinstance(x, str)
and len(x) > 16, # Arbitrary length right now
).ask()
ctx.obj.self_hosted = q.confirm(
message="Is your installation of Launch self-hosted?",
default=ctx.obj.self_hosted,
).ask()
if ctx.obj.self_hosted:
ctx.obj.gateway_endpoint = q.text(
message="Your Gateway Endpoint?",
default=ctx.obj.gateway_endpoint or "",
).ask()

ctx.obj.save()
8 changes: 4 additions & 4 deletions launch/cli/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def list_endpoints(ctx: click.Context):
servable_endpoint.model_endpoint.status,
servable_endpoint.model_endpoint.endpoint_type,
str(
(servable_endpoint.model_endpoint.worker_settings or {}).get(
(servable_endpoint.model_endpoint.deployment_state or {}).get(
"min_workers", ""
)
),
str(
(servable_endpoint.model_endpoint.worker_settings or {}).get(
(servable_endpoint.model_endpoint.deployment_state or {}).get(
"max_workers", ""
)
),
str(
(servable_endpoint.model_endpoint.worker_settings or {}).get(
(servable_endpoint.model_endpoint.deployment_state or {}).get(
"available_workers", ""
)
),
str(
(servable_endpoint.model_endpoint.worker_settings or {}).get(
(servable_endpoint.model_endpoint.deployment_state or {}).get(
"unavailable_workers", ""
)
),
Expand Down
78 changes: 77 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ click = "^8.0.0"
frozendict = "^2.3.4"
pydantic = "^1.10.4"
types-frozendict = "^2.0.9"
click-config-file = "^0.6.0"
questionary = "^1.10.0"

[tool.poetry.dev-dependencies]
black = "^22.1.0"
Expand Down