Skip to content

Commit 71c908c

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into feature/model_group
2 parents 97568d5 + e9aaef2 commit 71c908c

File tree

5 files changed

+75
-9
lines changed

5 files changed

+75
-9
lines changed

ads/aqua/client/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,19 @@ def embeddings(
582582
payload = {**(payload or {}), "input": input}
583583
return self._request(payload=payload, headers=headers)
584584

585+
def fetch_data(self) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
586+
"""Fetch Data in json format by sending a request to the endpoint.
587+
588+
Args:
589+
590+
Returns:
591+
Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: The server's response, typically including the data in JSON format.
592+
"""
593+
# headers = {"Content-Type", "application/json"}
594+
response = self._client.get(self.endpoint)
595+
json_response = response.json()
596+
return json_response
597+
585598

586599
class AsyncClient(BaseClient):
587600
"""

ads/aqua/extension/deployment_handler.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,37 @@ def post(self, *args, **kwargs): # noqa: ARG002
373373
)
374374

375375

376+
class AquaModelListHandler(AquaAPIhandler):
377+
"""Handler for Aqua model list params REST APIs.
378+
379+
Methods
380+
-------
381+
get(self, *args, **kwargs)
382+
Validates parameters for the given model id.
383+
"""
384+
385+
@handle_exceptions
386+
def get(self, model_deployment_id):
387+
"""
388+
Handles get model list for the Active Model Deployment
389+
Raises
390+
------
391+
HTTPError
392+
Raises HTTPError if inputs are missing or are invalid
393+
"""
394+
395+
self.set_header("Content-Type", "application/json")
396+
endpoint: str = ""
397+
model_deployment = AquaDeploymentApp().get(model_deployment_id)
398+
endpoint = model_deployment.endpoint.rstrip("/") + "/predict/v1/models"
399+
aqua_client = Client(endpoint=endpoint)
400+
try:
401+
list_model_result = aqua_client.fetch_data()
402+
return self.finish(list_model_result)
403+
except Exception as ex:
404+
raise HTTPError(500, str(ex))
405+
406+
376407
__handlers__ = [
377408
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
378409
("deployments/config/?([^/]*)", AquaDeploymentHandler),
@@ -381,4 +412,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
381412
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
382413
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
383414
("inference/stream/?([^/]*)", AquaDeploymentStreamingInferenceHandler),
415+
("deployments/models/list/?([^/]*)", AquaModelListHandler),
384416
]

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def test__load_default_properties(self, mock_from_ocid):
373373
"cpu_baseline": None,
374374
"ocpus": 10.0,
375375
"memory_in_gbs": 36.0,
376+
"cpu_baseline": None,
376377
},
377378
ModelDeploymentInfrastructure.CONST_REPLICA: 1,
378379
}

tests/unitary/default_setup/pipeline/test_pipeline.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
nb_session_ocid="ocid1.datasciencenotebooksession.oc1.iad..<unique_ocid>",
3333
shape_name="VM.Standard.E3.Flex",
3434
block_storage_size_in_gbs=100,
35-
shape_config_details={"ocpus": 1, "memory_in_gbs": 16},
35+
shape_config_details={"ocpus": 1.0, "memory_in_gbs": 16.0, "cpu_baseline": None},
3636
)
3737
PIPELINE_OCID = "ocid.xxx.datasciencepipeline.<unique_ocid>"
3838

@@ -334,10 +334,8 @@ def test_pipeline_define(self):
334334
"jobId": "TestJobIdOne",
335335
"description": "Test description one",
336336
"commandLineArguments": "ARGUMENT --KEY VALUE",
337-
"environmentVariables": {
338-
"ENV": "VALUE"
339-
},
340-
"maximumRuntimeInMinutes": 20
337+
"environmentVariables": {"ENV": "VALUE"},
338+
"maximumRuntimeInMinutes": 20,
341339
},
342340
},
343341
{
@@ -1066,10 +1064,8 @@ def test_pipeline_to_dict(self):
10661064
"jobId": "TestJobIdOne",
10671065
"description": "Test description one",
10681066
"commandLineArguments": "ARGUMENT --KEY VALUE",
1069-
"environmentVariables": {
1070-
"ENV": "VALUE"
1071-
},
1072-
"maximumRuntimeInMinutes": 20
1067+
"environmentVariables": {"ENV": "VALUE"},
1068+
"maximumRuntimeInMinutes": 20,
10731069
},
10741070
},
10751071
{

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
from parameterized import parameterized
1414

1515
import ads.aqua
16+
from ads.aqua.modeldeployment.entities import AquaDeploymentDetail
1617
import ads.config
1718
from ads.aqua.extension.deployment_handler import (
1819
AquaDeploymentHandler,
1920
AquaDeploymentParamsHandler,
2021
AquaDeploymentStreamingInferenceHandler,
22+
AquaModelListHandler,
2123
)
2224

2325

@@ -260,3 +262,25 @@ def test_post(self, mock_get_model_deployment_response):
260262
self.handler.write.assert_any_call("chunk1")
261263
self.handler.write.assert_any_call("chunk2")
262264
self.handler.finish.assert_called_once()
265+
266+
267+
class AquaModelListHandlerTestCase(unittest.TestCase):
268+
default_params = {
269+
"data": [{"id": "id", "object": "object", "owned_by": "openAI", "created": 124}]
270+
}
271+
272+
@patch.object(IPythonHandler, "__init__")
273+
def setUp(self, ipython_init_mock) -> None:
274+
ipython_init_mock.return_value = None
275+
self.aqua_model_list_handler = AquaModelListHandler(MagicMock(), MagicMock())
276+
self.aqua_model_list_handler._headers = MagicMock()
277+
278+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.get")
279+
@patch("notebook.base.handlers.APIHandler.finish")
280+
def test_get_model_list(self, mock_get, mock_finish):
281+
"""Test to check the handler get method to return model list."""
282+
283+
mock_get.return_value = MagicMock(id="test_model_id")
284+
mock_finish.side_effect = lambda x: x
285+
result = self.aqua_model_list_handler.get(model_id="test_model_id")
286+
mock_get.assert_called()

0 commit comments

Comments
 (0)