|
1 | 1 | """Nebius cloud adaptor."""
|
2 | 2 | import os
|
| 3 | +import threading |
| 4 | +from typing import Optional |
3 | 5 |
|
4 | 6 | from sky.adaptors import common
|
| 7 | +from sky.utils import annotations |
| 8 | +from sky.utils import ux_utils |
5 | 9 |
|
6 | 10 | NEBIUS_TENANT_ID_FILENAME = 'NEBIUS_TENANT_ID.txt'
|
7 | 11 | NEBIUS_IAM_TOKEN_FILENAME = 'NEBIUS_IAM_TOKEN.txt'
|
|
12 | 16 | NEBIUS_PROJECT_ID_PATH = '~/.nebius/' + NEBIUS_PROJECT_ID_FILENAME
|
13 | 17 | NEBIUS_CREDENTIALS_PATH = '~/.nebius/' + NEBIUS_CREDENTIALS_FILENAME
|
14 | 18 |
|
| 19 | +DEFAULT_REGION = 'eu-north1' |
| 20 | + |
| 21 | +NEBIUS_PROFILE_NAME = 'nebius' |
| 22 | + |
15 | 23 | MAX_RETRIES_TO_DISK_CREATE = 120
|
16 | 24 | MAX_RETRIES_TO_INSTANCE_STOP = 120
|
17 | 25 | MAX_RETRIES_TO_INSTANCE_START = 120
|
|
23 | 31 | POLL_INTERVAL = 5
|
24 | 32 |
|
25 | 33 | _iam_token = None
|
| 34 | +_sdk = None |
26 | 35 | _tenant_id = None
|
27 | 36 | _project_id = None
|
28 | 37 |
|
| 38 | +_IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for Nebius AI Cloud.' |
| 39 | + 'Try pip install "skypilot[nebius]"') |
| 40 | + |
29 | 41 | nebius = common.LazyImport(
|
30 | 42 | 'nebius',
|
31 |
| - import_error_message='Failed to import dependencies for Nebius AI Cloud. ' |
32 |
| - 'Try running: pip install "skypilot[nebius]"', |
| 43 | + import_error_message=_IMPORT_ERROR_MESSAGE, |
33 | 44 | # https://github.com/grpc/grpc/issues/37642 to avoid spam in console
|
34 | 45 | set_loggers=lambda: os.environ.update({'GRPC_VERBOSITY': 'NONE'}))
|
| 46 | +boto3 = common.LazyImport('boto3', import_error_message=_IMPORT_ERROR_MESSAGE) |
| 47 | +botocore = common.LazyImport('botocore', |
| 48 | + import_error_message=_IMPORT_ERROR_MESSAGE) |
| 49 | + |
| 50 | +_LAZY_MODULES = (boto3, botocore, nebius) |
| 51 | +_session_creation_lock = threading.RLock() |
| 52 | +_INDENT_PREFIX = ' ' |
| 53 | +NAME = 'Nebius' |
| 54 | +SKY_CHECK_NAME = 'Nebius (for Nebius Object Storae)' |
35 | 55 |
|
36 | 56 |
|
37 | 57 | def request_error():
|
@@ -104,7 +124,109 @@ def get_tenant_id():
|
104 | 124 |
|
105 | 125 |
|
106 | 126 | def sdk():
|
107 |
| - if get_iam_token() is not None: |
108 |
| - return nebius.sdk.SDK(credentials=get_iam_token()) |
109 |
| - return nebius.sdk.SDK( |
110 |
| - credentials_file_name=os.path.expanduser(NEBIUS_CREDENTIALS_PATH)) |
| 127 | + global _sdk |
| 128 | + if _sdk is None: |
| 129 | + if get_iam_token() is not None: |
| 130 | + _sdk = nebius.sdk.SDK(credentials=get_iam_token()) |
| 131 | + return _sdk |
| 132 | + _sdk = nebius.sdk.SDK( |
| 133 | + credentials_file_name=os.path.expanduser(NEBIUS_CREDENTIALS_PATH)) |
| 134 | + return _sdk |
| 135 | + |
| 136 | + |
| 137 | +def get_nebius_credentials(boto3_session): |
| 138 | + """Gets the Nebius credentials from the boto3 session object. |
| 139 | +
|
| 140 | + Args: |
| 141 | + boto3_session: The boto3 session object. |
| 142 | + Returns: |
| 143 | + botocore.credentials.ReadOnlyCredentials object with the R2 credentials. |
| 144 | + """ |
| 145 | + nebius_credentials = boto3_session.get_credentials() |
| 146 | + if nebius_credentials is None: |
| 147 | + with ux_utils.print_exception_no_traceback(): |
| 148 | + raise ValueError('Nebius credentials not found. Run ' |
| 149 | + '`sky check` to verify credentials are ' |
| 150 | + 'correctly set up.') |
| 151 | + return nebius_credentials.get_frozen_credentials() |
| 152 | + |
| 153 | + |
| 154 | +# lru_cache() is thread-safe and it will return the same session object |
| 155 | +# for different threads. |
| 156 | +# Reference: https://docs.python.org/3/library/functools.html#functools.lru_cache # pylint: disable=line-too-long |
| 157 | +@annotations.lru_cache(scope='global') |
| 158 | +def session(): |
| 159 | + """Create an AWS session.""" |
| 160 | + # Creating the session object is not thread-safe for boto3, |
| 161 | + # so we add a reentrant lock to synchronize the session creation. |
| 162 | + # Reference: https://github.com/boto/boto3/issues/1592 |
| 163 | + # However, the session object itself is thread-safe, so we are |
| 164 | + # able to use lru_cache() to cache the session object. |
| 165 | + with _session_creation_lock: |
| 166 | + session_ = boto3.session.Session(profile_name=NEBIUS_PROFILE_NAME) |
| 167 | + return session_ |
| 168 | + |
| 169 | + |
| 170 | +@annotations.lru_cache(scope='global') |
| 171 | +def resource(resource_name: str, region: str = DEFAULT_REGION, **kwargs): |
| 172 | + """Create a Nebius resource. |
| 173 | +
|
| 174 | + Args: |
| 175 | + resource_name: Nebius resource name (e.g., 's3'). |
| 176 | + kwargs: Other options. |
| 177 | + """ |
| 178 | + # Need to use the resource retrieved from the per-thread session |
| 179 | + # to avoid thread-safety issues (Directly creating the client |
| 180 | + # with boto3.resource() is not thread-safe). |
| 181 | + # Reference: https://stackoverflow.com/a/59635814 |
| 182 | + |
| 183 | + session_ = session() |
| 184 | + nebius_credentials = get_nebius_credentials(session_) |
| 185 | + endpoint = create_endpoint(region) |
| 186 | + |
| 187 | + return session_.resource( |
| 188 | + resource_name, |
| 189 | + endpoint_url=endpoint, |
| 190 | + aws_access_key_id=nebius_credentials.access_key, |
| 191 | + aws_secret_access_key=nebius_credentials.secret_key, |
| 192 | + region_name=region, |
| 193 | + **kwargs) |
| 194 | + |
| 195 | + |
| 196 | +@annotations.lru_cache(scope='global') |
| 197 | +def client(service_name: str, region): |
| 198 | + """Create an Nebius client of a certain service. |
| 199 | +
|
| 200 | + Args: |
| 201 | + service_name: Nebius service name (e.g., 's3'). |
| 202 | + kwargs: Other options. |
| 203 | + """ |
| 204 | + # Need to use the client retrieved from the per-thread session |
| 205 | + # to avoid thread-safety issues (Directly creating the client |
| 206 | + # with boto3.client() is not thread-safe). |
| 207 | + # Reference: https://stackoverflow.com/a/59635814 |
| 208 | + |
| 209 | + session_ = session() |
| 210 | + nebius_credentials = get_nebius_credentials(session_) |
| 211 | + endpoint = create_endpoint(region) |
| 212 | + |
| 213 | + return session_.client(service_name, |
| 214 | + endpoint_url=endpoint, |
| 215 | + aws_access_key_id=nebius_credentials.access_key, |
| 216 | + aws_secret_access_key=nebius_credentials.secret_key, |
| 217 | + region_name=region) |
| 218 | + |
| 219 | + |
| 220 | +@common.load_lazy_modules(_LAZY_MODULES) |
| 221 | +def botocore_exceptions(): |
| 222 | + """AWS botocore exception.""" |
| 223 | + # pylint: disable=import-outside-toplevel |
| 224 | + from botocore import exceptions |
| 225 | + return exceptions |
| 226 | + |
| 227 | + |
| 228 | +def create_endpoint(region: Optional[str] = DEFAULT_REGION) -> str: |
| 229 | + """Reads accountid necessary to interact with Nebius Object Storage""" |
| 230 | + if region is None: |
| 231 | + region = DEFAULT_REGION |
| 232 | + return f'https://storage.{region}.nebius.cloud:443' |
0 commit comments