
Commit 049d219

[Provisioner] New provisioner for GCP TPU VM (#2898)
* init
* test
* test ins_type
* fix
* format..
* wip
* remove TPU config
* fix node ips
* Fix TPU VM pod
* format
* use TPU VM as default
* Fix example for TPU VM
* format
* fix optimizer random dag
* set TPU-VM
* accelerator_args False
* backward compatibility
* add tpu filter for tests
* fix
* Fix
* fix status refresh for tpu VM pod
* Support autodown for TPU VM pod
* Allow multi-node TPU VM pod
* Allow multi-node TPU VM pod
* fix
* add execute for operation
* avoid from
* Wait for pending before set_labels
* format
* refactor constants
* Fix for API changes
* remove GCP failover handler v1
* format
* remove TPU VM pod specific codes as they have been moved to new provisioner
* Add error handling for TPU pod case
* fix
* fix multiple node calculation
* refactor tpu_utils to gcp_utils
* shorter time for recovering
* format

---------

Co-authored-by: Wei-Lin Chiang <[email protected]>
1 parent 27a295d commit 049d219

16 files changed: +446, -636 lines
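The commit makes TPU VM the default TPU deployment on GCP and moves its provisioning into the new SkyPilot provisioner. For context, below is a minimal sketch of requesting a TPU VM through SkyPilot's Python API; the accelerator name, run command, and cluster name are illustrative and not part of this commit.

# Minimal sketch (illustrative only): requesting a GCP TPU VM via SkyPilot's
# Python API, which the new provisioner now handles end to end.
import sky

task = sky.Task(run='python -c "import jax; print(jax.devices())"')
task.set_resources(
    sky.Resources(
        cloud=sky.GCP(),
        # TPU accelerator name; per this commit, TPU VM is the default
        # TPU deployment, so no extra accelerator_args are needed here.
        accelerators='tpu-v2-8',
    ))
sky.launch(task, cluster_name='my-tpu-vm')  # cluster name is illustrative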

sky/backends/backend_utils.py (+19, -106)
@@ -1,7 +1,6 @@
 """Util constants/functions for the backends."""
 from datetime import datetime
 import enum
-import json
 import os
 import pathlib
 import pprint
@@ -40,9 +39,9 @@
 from sky import status_lib
 from sky.backends import onprem_utils
 from sky.clouds import cloud_registry
+from sky.clouds.utils import gcp_utils
 from sky.provision import instance_setup
 from sky.skylet import constants
-from sky.skylet import log_lib
 from sky.usage import usage_lib
 from sky.utils import cluster_yaml_utils
 from sky.utils import command_runner
@@ -52,7 +51,6 @@
 from sky.utils import rich_utils
 from sky.utils import subprocess_utils
 from sky.utils import timeline
-from sky.utils import tpu_utils
 from sky.utils import ux_utils
 
 if typing.TYPE_CHECKING:
@@ -1107,7 +1105,7 @@ def write_cluster_config(
     usage_lib.messages.usage.update_ray_yaml(yaml_path)
 
     # For TPU nodes. TPU VMs do not need TPU_NAME.
-    if tpu_utils.is_tpu(to_provision) and not tpu_utils.is_tpu_vm(to_provision):
+    if gcp_utils.is_tpu(to_provision) and not gcp_utils.is_tpu_vm(to_provision):
         tpu_name = resources_vars.get('tpu_name')
         if tpu_name is None:
             tpu_name = cluster_name
@@ -1528,19 +1526,22 @@ def _query_head_ip_with_retries(cluster_yaml: str,
 
 
 @timeline.event
-def get_node_ips(cluster_yaml: str,
-                 expected_num_nodes: int,
-                 handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle',
-                 head_ip_max_attempts: int = 1,
-                 worker_ip_max_attempts: int = 1,
-                 get_internal_ips: bool = False) -> List[str]:
+def get_node_ips(
+        cluster_yaml: str,
+        expected_num_nodes: int,
+        # TODO: remove this argument once we remove the legacy on-prem
+        # support.
+        handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle',
+        head_ip_max_attempts: int = 1,
+        worker_ip_max_attempts: int = 1,
+        get_internal_ips: bool = False) -> List[str]:
     """Returns the IPs of all nodes in the cluster, with head node at front.
 
     Args:
         cluster_yaml: Path to the cluster yaml.
         expected_num_nodes: Expected number of nodes in the cluster.
-        handle: Cloud VM Ray resource handle. It is only required for TPU VM or
-            on-prem clusters.
+        handle: Cloud VM Ray resource handle. It is only required for on-prem
+            clusters.
         head_ip_max_attempts: Max attempts to get head ip.
         worker_ip_max_attempts: Max attempts to get worker ips.
         get_internal_ips: Whether to get internal IPs. When False, it is still
@@ -1551,18 +1552,13 @@ def get_node_ips(cluster_yaml: str,
         exceptions.FetchIPError: if we failed to get the IPs. e.reason is
             HEAD or WORKER.
     """
-    # When ray up launches TPU VM Pod, Pod workers (except for the head)
-    # won't be connected to Ray cluster. Thus "ray get-worker-ips"
-    # won't work and we need to query the node IPs with gcloud as
-    # implmented in _get_tpu_vm_pod_ips.
     ray_config = common_utils.read_yaml(cluster_yaml)
     # Use the new provisioner for AWS.
     provider_name = cluster_yaml_utils.get_provider_name(ray_config)
     cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name)
     assert cloud is not None, provider_name
 
-    if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT and
-            not tpu_utils.is_tpu_vm(handle.launched_resources)):
+    if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT:
         metadata = provision_lib.get_cluster_info(
             provider_name, ray_config['provider']['region'],
             ray_config['cluster_name'], ray_config['provider'])
@@ -1571,20 +1567,6 @@ def get_node_ips(cluster_yaml: str,
             raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
         return metadata.get_feasible_ips(get_internal_ips)
 
-    use_tpu_vm = ray_config['provider'].get('_has_tpus', False)
-    if use_tpu_vm:
-        assert expected_num_nodes == 1, (
-            'TPU VM only supports single node for now.')
-        assert handle is not None, 'handle is required for TPU VM.'
-        try:
-            ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
-        except exceptions.CommandError as e:
-            raise exceptions.FetchIPError(
-                exceptions.FetchIPError.Reason.HEAD) from e
-        if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources):
-            raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
-        return ips
-
     if get_internal_ips:
         with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
             ray_config['provider']['use_internal_ips'] = True
@@ -1658,74 +1640,6 @@ def get_node_ips(cluster_yaml: str,
     return head_ip_list + worker_ips
 
 
-@timeline.event
-def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
-                        get_internal_ips: bool = False) -> List[str]:
-    """Returns the IPs of all TPU VM Pod workers using gcloud."""
-
-    cluster_name = ray_config['cluster_name']
-    zone = ray_config['provider']['availability_zone']
-    query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
-                 f'"(labels.ray-cluster-name={cluster_name})" '
-                 f'--zone={zone} --format="value(name)"')
-    returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
-                                                      '/dev/null',
-                                                      shell=True,
-                                                      stream_logs=False,
-                                                      require_outputs=True)
-    subprocess_utils.handle_returncode(
-        returncode,
-        query_cmd,
-        'Failed to run gcloud to get TPU VM IDs.',
-        stderr=stdout + stderr)
-    if len(stdout) == 0:
-        logger.debug('No TPU VMs found with cluster name '
-                     f'{cluster_name} in zone {zone}.')
-    if len(stdout.splitlines()) > 1:
-        # Rare case, this could mean resource leakage. Hint user.
-        logger.warning('Found more than one TPU VM/Pod with the same cluster '
-                       f'name {cluster_name} in zone {zone}.')
-
-    all_ips = []
-    for tpu_id in stdout.splitlines():
-        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
-                     f' --zone {zone} --format=json')
-        returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
-                                                          os.devnull,
-                                                          shell=True,
-                                                          stream_logs=False,
-                                                          require_outputs=True)
-        subprocess_utils.handle_returncode(
-            returncode,
-            tpuvm_cmd,
-            'Failed to run gcloud tpu-vm describe.',
-            stderr=stdout + stderr)
-
-        tpuvm_json = json.loads(stdout)
-        if tpuvm_json['state'] != 'READY':
-            # May be a leaked preempted resource, or terminated by user in the
-            # console, or still in the process of being created.
-            ux_utils.console_newline()
-            logger.debug(f'TPU VM {tpu_id} is in {tpuvm_json["state"]} '
-                         'state. Skipping IP query... '
-                         'Hint: make sure it is not leaked.')
-            continue
-
-        ips = []
-        for endpoint in tpuvm_json['networkEndpoints']:
-            # Note: if TPU VM is being preempted, its IP field may not exist.
-            # We use get() to avoid KeyError.
-            if get_internal_ips:
-                ip = endpoint.get('ipAddress', None)
-            else:
-                ip = endpoint['accessConfig'].get('externalIp', None)
-            if ip is not None:
-                ips.append(ip)
-        all_ips.extend(ips)
-
-    return all_ips
-
-
 def check_network_connection():
     # Tolerate 3 retries as it is observed that connections can fail.
     adapter = adapters.HTTPAdapter(max_retries=retry_lib.Retry(total=3))
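With _get_tpu_vm_pod_ips removed, TPU VM IP resolution flows through the provisioner path shown earlier in this diff (provision_lib.get_cluster_info followed by metadata.get_feasible_ips). A rough sketch of that flow, with hypothetical glue code and assumed import aliases:

# Sketch only: approximates the provisioner-based IP lookup used above,
# replacing the per-cloud `gcloud compute tpus tpu-vm describe` calls.
from typing import List

from sky import provision as provision_lib  # assumed alias, as in the diff
from sky.utils import cluster_yaml_utils
from sky.utils import common_utils


def node_ips_via_provisioner(cluster_yaml: str,
                             get_internal_ips: bool = False) -> List[str]:
    """Resolve node IPs from provisioner metadata (hypothetical helper)."""
    ray_config = common_utils.read_yaml(cluster_yaml)
    provider_name = cluster_yaml_utils.get_provider_name(ray_config)
    metadata = provision_lib.get_cluster_info(
        provider_name, ray_config['provider']['region'],
        ray_config['cluster_name'], ray_config['provider'])
    # Same accessor as in the diff: head IP first, internal/external as asked.
    return metadata.get_feasible_ips(get_internal_ips)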
@@ -2014,9 +1928,8 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
             # in the worst case we time out in the `ray status` SSH command
             # below.
             external_ips = handle.cached_external_ips
-            # This happens to a stopped TPU VM as we use gcloud to query the IP.
-            # Or user interrupt the `sky launch` process before the first time
-            # resources handle is written back to local database.
+            # This happens when user interrupt the `sky launch` process before
+            # the first time resources handle is written back to local database.
             # This is helpful when user interrupt after the provision is done
             # and before the skylet is restarted. After #2304 is merged, this
             # helps keep the cluster status to INIT after `sky status -r`, so
@@ -2054,13 +1967,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
                 f'-- stdout --\n{output}\n-- stderr --\n{stderr}')
 
             ready_head, ready_workers = _count_healthy_nodes_from_ray(output)
-
-            if ready_head + ready_workers == handle.launched_nodes:
+            total_nodes = handle.launched_nodes * handle.num_node_ips
+            if ready_head + ready_workers == total_nodes:
                 return True
             raise RuntimeError(
                 f'Refreshing status ({cluster_name!r}): ray status not showing '
                 f'all nodes ({ready_head + ready_workers}/'
-                f'{handle.launched_nodes}); output: {output}; stderr: {stderr}')
+                f'{total_nodes}); output: {output}; stderr: {stderr}')
         except exceptions.FetchIPError:
             logger.debug(
                 f'Refreshing status ({cluster_name!r}) failed to get IPs.')
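The new health check accounts for multi-host TPU VM pods: a pod is recorded as one launched node but contributes one Ray node per TPU host, so the expected total is launched_nodes * num_node_ips (for ordinary single-IP nodes the product is unchanged). Illustrative numbers only:

# Illustrative only: a multi-host TPU VM pod slice with 4 hosts is one
# SkyPilot node (launched_nodes == 1) but four Ray nodes (num_node_ips == 4).
launched_nodes = 1
num_node_ips = 4
total_nodes = launched_nodes * num_node_ips  # == 4, matching `ray status`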
