"""Util constants/functions for the backends."""
from datetime import datetime
import enum
-import json
import os
import pathlib
import pprint

from sky import status_lib
from sky.backends import onprem_utils
from sky.clouds import cloud_registry
+from sky.clouds.utils import gcp_utils
from sky.provision import instance_setup
from sky.skylet import constants
-from sky.skylet import log_lib
from sky.usage import usage_lib
from sky.utils import cluster_yaml_utils
from sky.utils import command_runner
from sky.utils import rich_utils
from sky.utils import subprocess_utils
from sky.utils import timeline
-from sky.utils import tpu_utils
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
@@ -1107,7 +1105,7 @@ def write_cluster_config(
    usage_lib.messages.usage.update_ray_yaml(yaml_path)

    # For TPU nodes. TPU VMs do not need TPU_NAME.
-    if tpu_utils.is_tpu(to_provision) and not tpu_utils.is_tpu_vm(to_provision):
+    if gcp_utils.is_tpu(to_provision) and not gcp_utils.is_tpu_vm(to_provision):
        tpu_name = resources_vars.get('tpu_name')
        if tpu_name is None:
            tpu_name = cluster_name
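The helper names below come straight from this hunk; the wrapper function itself is only an illustrative sketch, not code from the change.

```python
# Illustrative wrapper around the check above. Only legacy TPU *nodes* (a TPU
# attached to a separate host VM) need TPU_NAME exported; TPU VMs run on the
# TPU host itself and skip this branch, as the original comment says.
def needs_tpu_name(to_provision) -> bool:
    return (gcp_utils.is_tpu(to_provision) and
            not gcp_utils.is_tpu_vm(to_provision))
```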
@@ -1528,19 +1526,22 @@ def _query_head_ip_with_retries(cluster_yaml: str,


@timeline.event
-def get_node_ips(cluster_yaml: str,
-                 expected_num_nodes: int,
-                 handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle',
-                 head_ip_max_attempts: int = 1,
-                 worker_ip_max_attempts: int = 1,
-                 get_internal_ips: bool = False) -> List[str]:
+def get_node_ips(
+        cluster_yaml: str,
+        expected_num_nodes: int,
+        # TODO: remove this argument once we remove the legacy on-prem
+        # support.
+        handle: 'cloud_vm_ray_backend.CloudVmRayResourceHandle',
+        head_ip_max_attempts: int = 1,
+        worker_ip_max_attempts: int = 1,
+        get_internal_ips: bool = False) -> List[str]:
    """Returns the IPs of all nodes in the cluster, with head node at front.

    Args:
        cluster_yaml: Path to the cluster yaml.
        expected_num_nodes: Expected number of nodes in the cluster.
-        handle: Cloud VM Ray resource handle. It is only required for TPU VM or
-            on-prem clusters.
+        handle: Cloud VM Ray resource handle. It is only required for on-prem
+            clusters.
        head_ip_max_attempts: Max attempts to get head ip.
        worker_ip_max_attempts: Max attempts to get worker ips.
        get_internal_ips: Whether to get internal IPs. When False, it is still
@@ -1551,18 +1552,13 @@ def get_node_ips(cluster_yaml: str,
        exceptions.FetchIPError: if we failed to get the IPs. e.reason is
            HEAD or WORKER.
    """
-    # When ray up launches TPU VM Pod, Pod workers (except for the head)
-    # won't be connected to Ray cluster. Thus "ray get-worker-ips"
-    # won't work and we need to query the node IPs with gcloud as
-    # implmented in _get_tpu_vm_pod_ips.
    ray_config = common_utils.read_yaml(cluster_yaml)
    # Use the new provisioner for AWS.
    provider_name = cluster_yaml_utils.get_provider_name(ray_config)
    cloud = cloud_registry.CLOUD_REGISTRY.from_str(provider_name)
    assert cloud is not None, provider_name

-    if (cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT and
-            not tpu_utils.is_tpu_vm(handle.launched_resources)):
+    if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT:
        metadata = provision_lib.get_cluster_info(
            provider_name, ray_config['provider']['region'],
            ray_config['cluster_name'], ray_config['provider'])
@@ -1571,20 +1567,6 @@ def get_node_ips(cluster_yaml: str,
            raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
        return metadata.get_feasible_ips(get_internal_ips)

-    use_tpu_vm = ray_config['provider'].get('_has_tpus', False)
-    if use_tpu_vm:
-        assert expected_num_nodes == 1, (
-            'TPU VM only supports single node for now.')
-        assert handle is not None, 'handle is required for TPU VM.'
-        try:
-            ips = _get_tpu_vm_pod_ips(ray_config, get_internal_ips)
-        except exceptions.CommandError as e:
-            raise exceptions.FetchIPError(
-                exceptions.FetchIPError.Reason.HEAD) from e
-        if len(ips) != tpu_utils.get_num_tpu_devices(handle.launched_resources):
-            raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
-        return ips
-
    if get_internal_ips:
        with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
            ray_config['provider']['use_internal_ips'] = True
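For orientation, a schematic of the dispatch this function now performs. The names mirror the hunk above; with the `tpu_utils.is_tpu_vm` exclusion dropped, TPU VMs take the metadata-based path as well. The `'aws'` value follows the hunk's own "new provisioner for AWS" comment, and the final `else` branch is only a placeholder for the `ray` CLI path shown further below.

```python
# Schematic only; not code from this change.
cloud = cloud_registry.CLOUD_REGISTRY.from_str('aws')
assert cloud is not None
if cloud.PROVISIONER_VERSION >= clouds.ProvisionerVersion.SKYPILOT:
    # Migrated clouds: IPs come from provision_lib metadata (the TPU-VM
    # carve-out in the old condition is gone).
    metadata = provision_lib.get_cluster_info(
        'aws', ray_config['provider']['region'],
        ray_config['cluster_name'], ray_config['provider'])
    ips = metadata.get_feasible_ips(get_internal_ips)
else:
    # Placeholder for the legacy path: query IPs via the `ray` CLI against a
    # (possibly internal-IP) copy of the cluster yaml.
    ips = ...
```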
@@ -1658,74 +1640,6 @@ def get_node_ips(cluster_yaml: str,
    return head_ip_list + worker_ips


-@timeline.event
-def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
-                        get_internal_ips: bool = False) -> List[str]:
-    """Returns the IPs of all TPU VM Pod workers using gcloud."""
-
-    cluster_name = ray_config['cluster_name']
-    zone = ray_config['provider']['availability_zone']
-    query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
-                 f'"(labels.ray-cluster-name={cluster_name})" '
-                 f'--zone={zone} --format="value(name)"')
-    returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
-                                                      '/dev/null',
-                                                      shell=True,
-                                                      stream_logs=False,
-                                                      require_outputs=True)
-    subprocess_utils.handle_returncode(
-        returncode,
-        query_cmd,
-        'Failed to run gcloud to get TPU VM IDs.',
-        stderr=stdout + stderr)
-    if len(stdout) == 0:
-        logger.debug('No TPU VMs found with cluster name '
-                     f'{cluster_name} in zone {zone}.')
-    if len(stdout.splitlines()) > 1:
-        # Rare case, this could mean resource leakage. Hint user.
-        logger.warning('Found more than one TPU VM/Pod with the same cluster '
-                       f'name {cluster_name} in zone {zone}.')
-
-    all_ips = []
-    for tpu_id in stdout.splitlines():
-        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
-                     f' --zone {zone} --format=json')
-        returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
-                                                          os.devnull,
-                                                          shell=True,
-                                                          stream_logs=False,
-                                                          require_outputs=True)
-        subprocess_utils.handle_returncode(
-            returncode,
-            tpuvm_cmd,
-            'Failed to run gcloud tpu-vm describe.',
-            stderr=stdout + stderr)
-
-        tpuvm_json = json.loads(stdout)
-        if tpuvm_json['state'] != 'READY':
-            # May be a leaked preempted resource, or terminated by user in the
-            # console, or still in the process of being created.
-            ux_utils.console_newline()
-            logger.debug(f'TPU VM {tpu_id} is in {tpuvm_json["state"]} '
-                         'state. Skipping IP query... '
-                         'Hint: make sure it is not leaked.')
-            continue
-
-        ips = []
-        for endpoint in tpuvm_json['networkEndpoints']:
-            # Note: if TPU VM is being preempted, its IP field may not exist.
-            # We use get() to avoid KeyError.
-            if get_internal_ips:
-                ip = endpoint.get('ipAddress', None)
-            else:
-                ip = endpoint['accessConfig'].get('externalIp', None)
-            if ip is not None:
-                ips.append(ip)
-        all_ips.extend(ips)
-
-    return all_ips
-
-
def check_network_connection():
    # Tolerate 3 retries as it is observed that connections can fail.
    adapter = adapters.HTTPAdapter(max_retries=retry_lib.Retry(total=3))
@@ -2014,9 +1928,8 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
            # in the worst case we time out in the `ray status` SSH command
            # below.
            external_ips = handle.cached_external_ips
-            # This happens to a stopped TPU VM as we use gcloud to query the IP.
-            # Or user interrupt the `sky launch` process before the first time
-            # resources handle is written back to local database.
+            # This happens when the user interrupts the `sky launch` process
+            # before the resources handle is first written back to the local database.
            # This is helpful when user interrupt after the provision is done
            # and before the skylet is restarted. After #2304 is merged, this
            # helps keep the cluster status to INIT after `sky status -r`, so
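A hedged sketch of the guard this comment sits next to; the exact check is outside the hunk, so its form here is an assumption.

```python
# Assumed shape of the guard: if no cached external IPs exist yet (e.g. the
# user interrupted `sky launch` before the handle was persisted), treat it as
# a failed head-IP fetch so the caller leaves the cluster in INIT instead of
# reporting it healthy.
external_ips = handle.cached_external_ips
if not external_ips:
    raise exceptions.FetchIPError(exceptions.FetchIPError.Reason.HEAD)
```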
@@ -2054,13 +1967,13 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool:
                f'-- stdout --\n{output}\n-- stderr --\n{stderr}')

            ready_head, ready_workers = _count_healthy_nodes_from_ray(output)
-
-            if ready_head + ready_workers == handle.launched_nodes:
+            total_nodes = handle.launched_nodes * handle.num_node_ips
+            if ready_head + ready_workers == total_nodes:
                return True
            raise RuntimeError(
                f'Refreshing status ({cluster_name!r}): ray status not showing '
                f'all nodes ({ready_head + ready_workers}/'
-                f'{handle.launched_nodes}); output: {output}; stderr: {stderr}')
+                f'{total_nodes}); output: {output}; stderr: {stderr}')
        except exceptions.FetchIPError:
            logger.debug(
                f'Refreshing status ({cluster_name!r}) failed to get IPs.')
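To make the new count concrete, a worked example with purely illustrative values (the numbers are assumptions, not taken from this change):

```python
# A cluster launched as 2 "nodes" where each node is a multi-host instance
# exposing 4 IPs, i.e. handle.launched_nodes == 2 and handle.num_node_ips == 4.
launched_nodes = 2
num_node_ips = 4
total_nodes = launched_nodes * num_node_ips  # 8 entries expected in `ray status`

ready_head, ready_workers = 1, 7
assert ready_head + ready_workers == total_nodes  # healthy: 8/8 ready
# The old check compared against launched_nodes (2) and could never pass here.
```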