Skip to content

Commit fced949

Browse files
Add support for AAD auth in data plane ops (#35186)
* add support for aad auth in data plane ops * resolve mypy issues * update changelog
1 parent bc135cb commit fced949

File tree

6 files changed

+91
-11
lines changed

6 files changed

+91
-11
lines changed

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Added Project entity class and YAML support.
99
- Project and Hub operations supported by workspace operations.
1010
- workspace list operation supports type filtering.
11+
- Add support for Microsoft Entra token (`aad_token`) auth in `invoke` and `get-credentials` operations.
1112

1213
### Bugs Fixed
1314

sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@
5656
AML_TOKEN_YAML = "aml_token"
5757
AAD_TOKEN_YAML = "aad_token"
5858
KEY = "key"
59+
AAD_TOKEN = "aadtoken"
60+
AAD_TOKEN_RESOURCE_ENDPOINT = "https://ml.azure.com"
61+
EMPTY_CREDENTIALS_ERROR = (
62+
"Credentials unavailable. Initialize credentials using 'MLClient' for SDK or 'az login' for CLI."
63+
)
5964
DEFAULT_ARM_RETRY_INTERVAL = 60
6065
COMPONENT_TYPE = "type"
6166
TID_FMT = "&tid={}"

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_endpoint/online_endpoint.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from azure.ai.ml.entities._mixins import RestTranslatableMixin
3030
from azure.ai.ml.entities._util import is_compute_in_override, load_from_dict
3131
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
32+
from azure.core.credentials import AccessToken
3233

3334
from ._endpoint_helpers import validate_endpoint_or_deployment_name, validate_identity_type_defined
3435
from .endpoint import Endpoint
@@ -625,3 +626,22 @@ def _to_rest_object(self) -> RestEndpointAuthToken:
625626
refresh_after_time_utc=self.refresh_after_time_utc,
626627
token_type=self.token_type,
627628
)
629+
630+
631+
class EndpointAadToken:
632+
"""Endpoint aad token.
633+
634+
:ivar access_token: Access token for aad authentication.
635+
:vartype access_token: str
636+
:ivar expiry_time_utc: Access token expiry time (UTC).
637+
:vartype expiry_time_utc: float
638+
"""
639+
640+
def __init__(self, obj: AccessToken):
641+
"""Constructor for Endpoint aad token.
642+
643+
:param obj: Access token object
644+
:type obj: AccessToken
645+
"""
646+
self.access_token = obj.token
647+
self.expiry_time_utc = obj.expires_on

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_online_endpoint_operations.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from marshmallow.exceptions import ValidationError as SchemaValidationError
1111

12+
from azure.ai.ml._azure_environments import _resource_to_scopes
1213
from azure.ai.ml._exception_helper import log_and_raise_error
1314
from azure.ai.ml._restclient.v2022_02_01_preview import AzureMachineLearningWorkspaces as ServiceClient022022Preview
1415
from azure.ai.ml._restclient.v2022_02_01_preview.models import KeyType, RegenerateEndpointKeysRequest
@@ -23,11 +24,18 @@
2324
from azure.ai.ml._utils._endpoint_utils import validate_response
2425
from azure.ai.ml._utils._http_utils import HttpPipeline
2526
from azure.ai.ml._utils._logger_utils import OpsLogger
26-
from azure.ai.ml.constants._common import KEY, AzureMLResourceType, LROConfigurations
27+
from azure.ai.ml.constants._common import (
28+
AAD_TOKEN,
29+
AAD_TOKEN_RESOURCE_ENDPOINT,
30+
EMPTY_CREDENTIALS_ERROR,
31+
KEY,
32+
AzureMLResourceType,
33+
LROConfigurations,
34+
)
2735
from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointKeyType
2836
from azure.ai.ml.entities import OnlineDeployment, OnlineEndpoint
2937
from azure.ai.ml.entities._assets import Data
30-
from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAuthKeys, EndpointAuthToken
38+
from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAadToken, EndpointAuthKeys, EndpointAuthToken
3139
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
3240
from azure.ai.ml.operations._local_endpoint_helper import _LocalEndpointHelper
3341
from azure.core.credentials import TokenCredential
@@ -96,7 +104,7 @@ def list(self, *, local: bool = False) -> ItemPaged[OnlineEndpoint]:
96104

97105
@distributed_trace
98106
@monitor_with_activity(ops_logger, "OnlineEndpoint.ListKeys", ActivityType.PUBLICAPI)
99-
def get_keys(self, name: str) -> Union[EndpointAuthKeys, EndpointAuthToken]:
107+
def get_keys(self, name: str) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]:
100108
"""Get the auth credentials.
101109
102110
:param name: The endpoint name
@@ -344,7 +352,7 @@ def invoke(
344352
keys = self._get_online_credentials(name=endpoint_name, auth_mode=endpoint.properties.auth_mode)
345353
if isinstance(keys, EndpointAuthKeys):
346354
key = keys.primary_key
347-
elif isinstance(keys, EndpointAuthToken):
355+
elif isinstance(keys, (EndpointAuthToken, EndpointAadToken)):
348356
key = keys.access_token
349357
else:
350358
key = ""
@@ -365,7 +373,7 @@ def _get_workspace_location(self) -> str:
365373

366374
def _get_online_credentials(
367375
self, name: str, auth_mode: Optional[str] = None
368-
) -> Union[EndpointAuthKeys, EndpointAuthToken]:
376+
) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]:
369377
if not auth_mode:
370378
endpoint = self._online_operation.get(
371379
resource_group_name=self._resource_group_name,
@@ -384,6 +392,11 @@ def _get_online_credentials(
384392
**self._init_kwargs,
385393
)
386394

395+
if auth_mode is not None and auth_mode.lower() == AAD_TOKEN:
396+
if self._credentials:
397+
return EndpointAadToken(self._credentials.get_token(*_resource_to_scopes(AAD_TOKEN_RESOURCE_ENDPOINT)))
398+
raise Exception(EMPTY_CREDENTIALS_ERROR)
399+
387400
return self._online_operation.get_token(
388401
resource_group_name=self._resource_group_name,
389402
workspace_name=self._workspace_name,

sdk/ml/azure-ai-ml/cspell.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"version": "0.2",
33
"ignoreWords": [
44
"kwoa",
5-
"rslex"
5+
"rslex",
6+
"aadtoken"
67
]
7-
}
8+
}

sdk/ml/azure-ai-ml/tests/online_services/unittests/test_online_endpoints.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,27 @@
33

44
import pytest
55
from pytest_mock import MockFixture
6-
from azure.ai.ml import load_online_endpoint, load_online_deployment
7-
from azure.ai.ml._restclient.v2022_10_01.models import EndpointAuthKeys
6+
7+
from azure.ai.ml import load_online_deployment, load_online_endpoint
8+
from azure.ai.ml._azure_environments import _resource_to_scopes
89
from azure.ai.ml._restclient.v2022_02_01_preview.models import (
910
KubernetesOnlineDeployment as RestKubernetesOnlineDeployment,
1011
)
11-
from azure.ai.ml.entities._util import load_from_dict
1212
from azure.ai.ml._restclient.v2022_02_01_preview.models import (
1313
OnlineDeploymentData,
1414
OnlineDeploymentDetails,
1515
OnlineEndpointData,
1616
)
1717
from azure.ai.ml._restclient.v2022_02_01_preview.models import OnlineEndpointDetails as RestOnlineEndpoint
18+
from azure.ai.ml._restclient.v2022_10_01.models import EndpointAuthKeys
1819
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope
19-
from azure.ai.ml.constants._common import AzureMLResourceType, HttpResponseStatusCode
20+
from azure.ai.ml.constants._common import (
21+
AAD_TOKEN_RESOURCE_ENDPOINT,
22+
EMPTY_CREDENTIALS_ERROR,
23+
AzureMLResourceType,
24+
HttpResponseStatusCode,
25+
)
26+
from azure.ai.ml.entities._util import load_from_dict
2027
from azure.ai.ml.operations import (
2128
DatastoreOperations,
2229
EnvironmentOperations,
@@ -284,6 +291,39 @@ def test_online_get_token(
284291
mock_online_endpoint_operations._online_operation.get.assert_called_once()
285292
mock_online_endpoint_operations._online_operation.get_token.assert_called_once()
286293

294+
def test_online_aad_get_token(
295+
self,
296+
mock_online_endpoint_operations: OnlineEndpointOperations,
297+
mock_aml_services_2022_02_01_preview: Mock,
298+
) -> None:
299+
random_name = "random_name"
300+
mock_aml_services_2022_02_01_preview.online_endpoints.get.return_value = OnlineEndpointData(
301+
name=random_name,
302+
location="eastus",
303+
properties=RestOnlineEndpoint(auth_mode="aadtoken"),
304+
)
305+
mock_online_endpoint_operations._credentials = Mock(spec_set=DefaultAzureCredential)
306+
mock_online_endpoint_operations.get_keys(name=random_name)
307+
mock_online_endpoint_operations._online_operation.get.assert_called_once()
308+
mock_online_endpoint_operations._credentials.get_token.assert_called_once_with(
309+
*_resource_to_scopes(AAD_TOKEN_RESOURCE_ENDPOINT)
310+
)
311+
312+
def test_online_aad_get_token_with_empty_credentials(
313+
self,
314+
mock_online_endpoint_operations: OnlineEndpointOperations,
315+
mock_aml_services_2022_02_01_preview: Mock,
316+
) -> None:
317+
random_name = "random_name"
318+
mock_aml_services_2022_02_01_preview.online_endpoints.get.return_value = OnlineEndpointData(
319+
name=random_name,
320+
location="eastus",
321+
properties=RestOnlineEndpoint(auth_mode="aadtoken"),
322+
)
323+
with pytest.raises(Exception) as ex:
324+
mock_online_endpoint_operations.get_keys(name=random_name)
325+
assert EMPTY_CREDENTIALS_ERROR in str(ex)
326+
287327
def test_online_delete(
288328
self,
289329
mock_online_endpoint_operations: OnlineEndpointOperations,

0 commit comments

Comments
 (0)