Skip to content

Commit 7dee94f

Browse files
committed
fixes
1 parent 44efbf4 commit 7dee94f

File tree

3 files changed

+48
-36
lines changed

3 files changed

+48
-36
lines changed

launch/api_client/api/default_api.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(self, api_client=None):
8787
settings={
8888
"response_type": (CreateAsyncTaskResponse,),
8989
"auth": ["HTTPBasic"],
90-
"endpoint_path": "/async-tasks",
90+
"endpoint_path": "/v1/async-tasks",
9191
"operation_id": "create_async_inference_task_v1_async_tasks_post",
9292
"http_method": "POST",
9393
"servers": None,
@@ -135,7 +135,7 @@ def __init__(self, api_client=None):
135135
settings={
136136
"response_type": (CreateModelBundleResponse,),
137137
"auth": ["HTTPBasic"],
138-
"endpoint_path": "/model-bundles",
138+
"endpoint_path": "/v1/model-bundles",
139139
"operation_id": "create_model_bundle_v1_model_bundles_post",
140140
"http_method": "POST",
141141
"servers": None,
@@ -178,7 +178,7 @@ def __init__(self, api_client=None):
178178
settings={
179179
"response_type": (CreateModelEndpointResponse,),
180180
"auth": ["HTTPBasic"],
181-
"endpoint_path": "/model-endpoints",
181+
"endpoint_path": "/v1/model-endpoints",
182182
"operation_id": "create_model_endpoint_v1_model_endpoints_post",
183183
"http_method": "POST",
184184
"servers": None,
@@ -223,7 +223,7 @@ def __init__(self, api_client=None):
223223
settings={
224224
"response_type": (SyncEndpointPredictResponse,),
225225
"auth": ["HTTPBasic"],
226-
"endpoint_path": "/sync-tasks",
226+
"endpoint_path": "/v1/sync-tasks",
227227
"operation_id": "create_sync_inference_task_v1_sync_tasks_post",
228228
"http_method": "POST",
229229
"servers": None,
@@ -271,7 +271,7 @@ def __init__(self, api_client=None):
271271
settings={
272272
"response_type": (DeleteModelEndpointResponse,),
273273
"auth": ["HTTPBasic"],
274-
"endpoint_path": "/model-endpoints/{model_endpoint_id}",
274+
"endpoint_path": "/v1/model-endpoints/{model_endpoint_id}",
275275
"operation_id": "delete_model_endpoint_v1_model_endpoints_model_endpoint_id_delete",
276276
"http_method": "DELETE",
277277
"servers": None,
@@ -315,7 +315,7 @@ def __init__(self, api_client=None):
315315
settings={
316316
"response_type": (GetAsyncTaskResponse,),
317317
"auth": ["HTTPBasic"],
318-
"endpoint_path": "/async-tasks/{task_id}",
318+
"endpoint_path": "/v1/async-tasks/{task_id}",
319319
"operation_id": "get_async_inference_task_v1_async_tasks_task_id_get",
320320
"http_method": "GET",
321321
"servers": None,
@@ -359,7 +359,7 @@ def __init__(self, api_client=None):
359359
settings={
360360
"response_type": (ModelBundleResponse,),
361361
"auth": ["HTTPBasic"],
362-
"endpoint_path": "/model-bundles/latest",
362+
"endpoint_path": "/v1/model-bundles/latest",
363363
"operation_id": "get_latest_model_bundle_v1_model_bundles_latest_get",
364364
"http_method": "GET",
365365
"servers": None,
@@ -403,7 +403,7 @@ def __init__(self, api_client=None):
403403
settings={
404404
"response_type": (ModelBundleResponse,),
405405
"auth": ["HTTPBasic"],
406-
"endpoint_path": "/model-bundles/{model_bundle_id}",
406+
"endpoint_path": "/v1/model-bundles/{model_bundle_id}",
407407
"operation_id": "get_model_bundle_v1_model_bundles_model_bundle_id_get",
408408
"http_method": "GET",
409409
"servers": None,
@@ -447,7 +447,7 @@ def __init__(self, api_client=None):
447447
settings={
448448
"response_type": (GetModelEndpointResponse,),
449449
"auth": ["HTTPBasic"],
450-
"endpoint_path": "/model-endpoints/{model_endpoint_id}",
450+
"endpoint_path": "/v1/model-endpoints/{model_endpoint_id}",
451451
"operation_id": "get_model_endpoint_v1_model_endpoints_model_endpoint_id_get",
452452
"http_method": "GET",
453453
"servers": None,
@@ -611,7 +611,7 @@ def __init__(self, api_client=None):
611611
settings={
612612
"response_type": (ListModelBundlesResponse,),
613613
"auth": ["HTTPBasic"],
614-
"endpoint_path": "/model-bundles",
614+
"endpoint_path": "/v1/model-bundles",
615615
"operation_id": "list_model_bundles_v1_model_bundles_get",
616616
"http_method": "GET",
617617
"servers": None,
@@ -657,7 +657,7 @@ def __init__(self, api_client=None):
657657
settings={
658658
"response_type": (ListModelEndpointsResponse,),
659659
"auth": ["HTTPBasic"],
660-
"endpoint_path": "/model-endpoints",
660+
"endpoint_path": "/v1/model-endpoints",
661661
"operation_id": "list_model_endpoints_v1_model_endpoints_get",
662662
"http_method": "GET",
663663
"servers": None,
@@ -703,7 +703,7 @@ def __init__(self, api_client=None):
703703
settings={
704704
"response_type": (UpdateModelEndpointResponse,),
705705
"auth": ["HTTPBasic"],
706-
"endpoint_path": "/model-endpoints/{model_endpoint_id}",
706+
"endpoint_path": "/v1/model-endpoints/{model_endpoint_id}",
707707
"operation_id": "update_model_endpoint_v1_model_endpoints_model_endpoint_id_put",
708708
"http_method": "PUT",
709709
"servers": None,

launch/client.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ def __init__(
115115
endpoint: The Scale Launch Endpoint (this should not need to be changed)
116116
self_hosted: True iff you are connecting to a self-hosted Scale Launch
117117
"""
118-
self.connection = Connection(api_key, endpoint or SCALE_LAUNCH_ENDPOINT)
118+
self.connection = Connection(
119+
api_key, endpoint or SCALE_LAUNCH_ENDPOINT
120+
)
119121
self.self_hosted = self_hosted
120122
self.upload_bundle_fn: Optional[Callable[[str, str], None]] = None
121123
self.upload_batch_csv_fn: Optional[Callable[[str, str], None]] = None
@@ -125,7 +127,7 @@ def __init__(
125127
self.bundle_location_fn: Optional[Callable[[], str]] = None
126128
self.batch_csv_location_fn: Optional[Callable[[], str]] = None
127129
self.configuration = Configuration(
128-
host=endpoint,
130+
host=endpoint, # host="host.docker.internal:3000/v1/launch",
129131
discard_unknown_keys=True,
130132
username=api_key,
131133
password="",
@@ -851,9 +853,13 @@ def get_model_endpoint(
851853
resp = resp.model_endpoints[0]
852854

853855
if resp["endpoint_type"].value == "async":
854-
return AsyncEndpoint(ModelEndpoint.from_dict(resp.to_dict()), client=self) # type: ignore
856+
return AsyncEndpoint(
857+
ModelEndpoint.from_dict(resp.to_dict()), client=self # type: ignore
858+
)
855859
elif resp["endpoint_type"].value == "sync":
856-
return SyncEndpoint(ModelEndpoint.from_dict(resp.to_dict()), client=self) # type: ignore
860+
return SyncEndpoint(
861+
ModelEndpoint.from_dict(resp.to_dict()), client=self # type: ignore
862+
)
857863
else:
858864
raise ValueError(
859865
"Endpoint should be one of the types 'sync' or 'async'"
@@ -1071,12 +1077,17 @@ def _async_request(
10711077
with ApiClient(self.configuration) as api_client:
10721078
api_instance = DefaultApi(api_client)
10731079
request = EndpointPredictRequest(
1074-
return_pickled=return_pickled, url=url, args=args
1080+
return_pickled=return_pickled,
1081+
url=url,
1082+
args=args,
1083+
_check_type=False,
10751084
)
10761085
model_endpoint_id = endpoint.model_endpoint.id # type: ignore
1077-
resp = api_instance.create_sync_inference_task_v1_sync_tasks_post(
1078-
model_endpoint_id=model_endpoint_id,
1079-
endpoint_predict_request=request,
1086+
resp = (
1087+
api_instance.create_async_inference_task_v1_async_tasks_post(
1088+
model_endpoint_id=model_endpoint_id,
1089+
endpoint_predict_request=request,
1090+
)
10801091
)
10811092
return resp
10821093

@@ -1098,7 +1109,7 @@ def _get_async_endpoint_response(
10981109
10991110
The dictionary's keys are as follows:
11001111
1101-
- ``state``: ``'PENDING'`` or ``'SUCCESS'`` or ``'FAILURE'``
1112+
- ``status``: ``'PENDING'`` or ``'SUCCESS'`` or ``'FAILURE'``
11021113
- ``result_url``: a url pointing to inference results. This url is accessible for 12 hours after the request has been made.
11031114
- ``result``: the value returned by the endpoint's `predict` function, serialized as json
11041115
@@ -1107,7 +1118,7 @@ def _get_async_endpoint_response(
11071118
.. code-block:: json
11081119
11091120
{
1110-
'state': 'SUCCESS',
1121+
'status': 'SUCCESS',
11111122
'result_url': 'https://foo.s3.us-west-2.amazonaws.com/bar/baz/qux?xyzzy'
11121123
}
11131124

launch/model_endpoint.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,28 +195,29 @@ def get(self) -> EndpointResponse:
195195
async_response = self.client._get_async_endpoint_response( # pylint: disable=W0212
196196
self.endpoint_name, self.async_task_id
197197
)
198-
if async_response["state"] == "PENDING":
198+
status = async_response["status"].value
199+
if status == "PENDING":
199200
time.sleep(2)
200201
else:
201-
if async_response["state"] == "SUCCESS":
202+
if status == "SUCCESS":
202203
return EndpointResponse(
203204
client=self.client,
204-
status=async_response["state"],
205+
status=status,
205206
result_url=async_response.get("result_url", None),
206207
result=async_response.get("result", None),
207208
traceback=None,
208209
)
209-
elif async_response["state"] == "FAILURE":
210+
elif status == "FAILURE":
210211
return EndpointResponse(
211212
client=self.client,
212-
status=async_response["state"],
213+
status=status,
213214
result_url=None,
214215
result=None,
215216
traceback=async_response.get("traceback", None),
216217
)
217218
else:
218219
raise ValueError(
219-
f"Unrecognized state: {async_response['state']}"
220+
f"Unrecognized status: {async_response['status']}"
220221
)
221222

222223

@@ -291,7 +292,7 @@ def predict(self, request: EndpointRequest) -> EndpointResponse:
291292
)
292293
return EndpointResponse(
293294
client=self.client,
294-
status=raw_response.get("state"),
295+
status=raw_response.get("status"),
295296
result_url=raw_response.get("result_url", None),
296297
result=raw_response.get("result", None),
297298
traceback=raw_response.get("traceback", None),
@@ -450,7 +451,7 @@ def single_request(inner_url, inner_task_id):
450451
return (
451452
inner_url,
452453
inner_task_id,
453-
inner_response.get("state", None),
454+
inner_response.get("status", None),
454455
inner_response,
455456
)
456457

@@ -464,13 +465,13 @@ def single_request(inner_url, inner_task_id):
464465
for response in responses:
465466
if response is None:
466467
continue
467-
url, _, state, raw_response = response
468-
if state:
469-
self.statuses[url] = state
468+
url, _, status, raw_response = response
469+
if status:
470+
self.statuses[url] = status
470471
if raw_response:
471472
response_object = EndpointResponse(
472473
client=self.client,
473-
status=raw_response["state"],
474+
status=raw_response["status"],
474475
result_url=raw_response.get("result_url", None),
475476
result=raw_response.get("result", None),
476477
traceback=raw_response.get("traceback", None),
@@ -479,10 +480,10 @@ def single_request(inner_url, inner_task_id):
479480

480481
def is_done(self, poll=True) -> bool:
481482
"""
482-
Checks the client local state to see if all requests are done.
483+
Checks the client local status to see if all requests are done.
483484
484485
Parameters:
485-
poll: If ``True``, then this will first check the state for a subset
486+
poll: If ``True``, then this will first check the status for a subset
486487
of the remaining incomplete tasks on the Launch server.
487488
"""
488489
# TODO: make some request to some endpoint

0 commit comments

Comments
 (0)