|
13 | 13 |
|
14 | 14 | from launch.api_client import ApiClient, Configuration
|
15 | 15 | from launch.api_client.api.default_api import DefaultApi
|
16 |
| -from launch.api_client.model.create_model_bundle_request import CreateModelBundleRequest |
17 |
| -from launch.api_client.model.create_model_endpoint_request import CreateModelEndpointRequest |
18 |
| -from launch.api_client.model.endpoint_predict_request import EndpointPredictRequest |
19 |
| -from launch.api_client.model.update_model_endpoint_request import UpdateModelEndpointRequest |
| 16 | +from launch.api_client.model.create_model_bundle_request import ( |
| 17 | + CreateModelBundleRequest, |
| 18 | +) |
| 19 | +from launch.api_client.model.create_model_endpoint_request import ( |
| 20 | + CreateModelEndpointRequest, |
| 21 | +) |
| 22 | +from launch.api_client.model.endpoint_predict_request import ( |
| 23 | + EndpointPredictRequest, |
| 24 | +) |
| 25 | +from launch.api_client.model.update_model_endpoint_request import ( |
| 26 | + UpdateModelEndpointRequest, |
| 27 | +) |
20 | 28 | from launch.connection import Connection
|
21 | 29 | from launch.constants import (
|
22 | 30 | BATCH_TASK_INPUT_SIGNED_URL_PATH,
|
@@ -655,7 +663,10 @@ def create_model_endpoint(
|
655 | 663 | logger.info("Creating new endpoint")
|
656 | 664 | with ApiClient(self.configuration) as api_client:
|
657 | 665 | api_instance = DefaultApi(api_client)
|
658 |
| - if not isinstance(model_bundle, ModelBundle) or model_bundle.id is None: |
| 666 | + if ( |
| 667 | + not isinstance(model_bundle, ModelBundle) |
| 668 | + or model_bundle.id is None |
| 669 | + ): |
659 | 670 | model_bundle = self.get_model_bundle(model_bundle)
|
660 | 671 | create_model_endpoint_request = CreateModelEndpointRequest(
|
661 | 672 | cpus=cpus,
|
@@ -760,19 +771,27 @@ def edit_model_endpoint(
|
760 | 771 |
|
761 | 772 | if model_bundle is None:
|
762 | 773 | model_bundle_id = None
|
763 |
| - elif isinstance(model_bundle, ModelBundle) and model_bundle.id is not None: |
| 774 | + elif ( |
| 775 | + isinstance(model_bundle, ModelBundle) |
| 776 | + and model_bundle.id is not None |
| 777 | + ): |
764 | 778 | model_bundle_id = model_bundle.id
|
765 | 779 | else:
|
766 | 780 | model_bundle = self.get_model_bundle(model_bundle)
|
767 | 781 | model_bundle_id = model_bundle.id
|
768 | 782 |
|
769 | 783 | if model_endpoint is None:
|
770 | 784 | model_endpoint_id = None
|
771 |
| - elif isinstance(model_endpoint, ModelEndpoint) and model_endpoint.id is not None: |
| 785 | + elif ( |
| 786 | + isinstance(model_endpoint, ModelEndpoint) |
| 787 | + and model_endpoint.id is not None |
| 788 | + ): |
772 | 789 | model_endpoint_id = model_endpoint.id
|
773 | 790 | else:
|
774 | 791 | endpoint_name = _model_endpoint_to_name(model_endpoint)
|
775 |
| - model_endpoint = self.get_model_endpoint(endpoint_name).model_endpoint |
| 792 | + model_endpoint = self.get_model_endpoint( |
| 793 | + endpoint_name |
| 794 | + ).model_endpoint |
776 | 795 | model_endpoint_id = model_endpoint.id
|
777 | 796 |
|
778 | 797 | update_model_endpoint_request = UpdateModelEndpointRequest(
|
@@ -988,7 +1007,9 @@ def _sync_request(
|
988 | 1007 | validate_task_request(url=url, args=args)
|
989 | 1008 | with ApiClient(self.configuration) as api_client:
|
990 | 1009 | api_instance = DefaultApi(api_client)
|
991 |
| - request = EndpointPredictRequest(return_pickled=return_pickled, url=url, args=args) |
| 1010 | + request = EndpointPredictRequest( |
| 1011 | + return_pickled=return_pickled, url=url, args=args |
| 1012 | + ) |
992 | 1013 | resp = api_instance.create_sync_inference_task_v1_sync_tasks_post(
|
993 | 1014 | model_endpoint_id=endpoint_id,
|
994 | 1015 | endpoint_predict_request=request,
|
@@ -1033,7 +1054,9 @@ def _async_request(
|
1033 | 1054 | endpoint = self.get_model_endpoint(endpoint_name)
|
1034 | 1055 | with ApiClient(self.configuration) as api_client:
|
1035 | 1056 | api_instance = DefaultApi(api_client)
|
1036 |
| - request = EndpointPredictRequest(return_pickled=return_pickled, url=url, args=args) |
| 1057 | + request = EndpointPredictRequest( |
| 1058 | + return_pickled=return_pickled, url=url, args=args |
| 1059 | + ) |
1037 | 1060 | resp = api_instance.create_sync_inference_task_v1_sync_tasks_post(
|
1038 | 1061 | model_endpoint_id=endpoint.model_endpoint.id,
|
1039 | 1062 | endpoint_predict_request=request,
|
|
0 commit comments