5
5
import shutil
6
6
import tempfile
7
7
from io import StringIO
8
- from typing import Any , Callable , Dict , List , Optional , TypeVar , Union
8
+ from typing import Any , Callable , Dict , List , Optional , Type , TypeVar , Union
9
9
from zipfile import ZipFile
10
10
11
11
import cloudpickle
12
12
import requests
13
13
import yaml
14
14
from frozendict import frozendict
15
+ from pydantic import BaseModel
15
16
16
17
from launch .api_client import ApiClient , Configuration
17
18
from launch .api_client .apis .tags .default_api import DefaultApi
59
60
ModelEndpoint ,
60
61
SyncEndpoint ,
61
62
)
63
+ from launch .pydantic_schemas import get_model_definitions
62
64
from launch .request_validation import validate_task_request
63
65
64
66
DEFAULT_NETWORK_TIMEOUT_SEC = 120
@@ -224,15 +226,39 @@ def register_endpoint_auth_decorator(self, endpoint_auth_decorator_fn):
224
226
"""
225
227
self .endpoint_auth_decorator_fn = endpoint_auth_decorator_fn
226
228
229
+ def _upload_data (self , data : bytes ) -> str :
230
+ if self .self_hosted :
231
+ if self .upload_bundle_fn is None :
232
+ raise ValueError ("Upload_bundle_fn should be registered" )
233
+ if self .bundle_location_fn is None :
234
+ raise ValueError (
235
+ "Need either bundle_location_fn to know where to upload bundles"
236
+ )
237
+ raw_bundle_url = self .bundle_location_fn () # type: ignore
238
+ self .upload_bundle_fn (data , raw_bundle_url ) # type: ignore
239
+ else :
240
+ model_bundle_url = self .connection .post (
241
+ {}, MODEL_BUNDLE_SIGNED_URL_PATH
242
+ )
243
+ s3_path = model_bundle_url ["signedUrl" ]
244
+ raw_bundle_url = (
245
+ f"s3://{ model_bundle_url ['bucket' ]} /{ model_bundle_url ['key' ]} "
246
+ )
247
+ requests .put (s3_path , data = data )
248
+ return raw_bundle_url
249
+
227
250
def create_model_bundle_from_dirs (
228
251
self ,
252
+ * ,
229
253
model_bundle_name : str ,
230
254
base_paths : List [str ],
231
255
requirements_path : str ,
232
256
env_params : Dict [str , str ],
233
257
load_predict_fn_module_path : str ,
234
258
load_model_fn_module_path : str ,
235
259
app_config : Optional [Union [Dict [str , Any ], str ]] = None ,
260
+ request_schema : Optional [Type [BaseModel ]] = None ,
261
+ response_schema : Optional [Type [BaseModel ]] = None ,
236
262
) -> ModelBundle :
237
263
"""
238
264
Packages up code from one or more local filesystem folders and uploads them as a bundle to Scale Launch.
@@ -302,6 +328,13 @@ def create_model_bundle_from_dirs(
302
328
the function located at load_predict_fn_module_path.
303
329
304
330
app_config: Either a Dictionary that represents a YAML file contents or a local path to a YAML file.
331
+
332
+ request_schema: A pydantic model that represents the request schema for the model
333
+ bundle. This is used to validate the request body for the model bundle's endpoint.
334
+
335
+ response_schema: A pydantic model that represents the request schema for the model
336
+ bundle. This is used to validate the response for the model bundle's endpoint.
337
+ Note: If request_schema is specified, then response_schema must also be specified.
305
338
"""
306
339
with open (requirements_path , "r" , encoding = "utf-8" ) as req_f :
307
340
requirements = req_f .read ().splitlines ()
@@ -315,24 +348,20 @@ def create_model_bundle_from_dirs(
315
348
finally :
316
349
shutil .rmtree (tmpdir )
317
350
318
- if self .self_hosted :
319
- if self .upload_bundle_fn is None :
320
- raise ValueError ("Upload_bundle_fn should be registered" )
321
- if self .bundle_location_fn is None :
322
- raise ValueError (
323
- "Need either bundle_location_fn to know where to upload bundles"
324
- )
325
- raw_bundle_url = self .bundle_location_fn () # type: ignore
326
- self .upload_bundle_fn (data , raw_bundle_url ) # type: ignore
327
- else :
328
- model_bundle_url = self .connection .post (
329
- {}, MODEL_BUNDLE_SIGNED_URL_PATH
351
+ raw_bundle_url = self ._upload_data (data )
352
+
353
+ schema_location = None
354
+ if bool (request_schema ) ^ bool (response_schema ):
355
+ raise ValueError (
356
+ "If request_schema is specified, then response_schema must also be specified."
330
357
)
331
- s3_path = model_bundle_url ["signedUrl" ]
332
- raw_bundle_url = (
333
- f"s3://{ model_bundle_url ['bucket' ]} /{ model_bundle_url ['key' ]} "
358
+ if request_schema is not None and response_schema is not None :
359
+ model_definitions = get_model_definitions (
360
+ request_schema = request_schema ,
361
+ response_schema = response_schema ,
334
362
)
335
- requests .put (s3_path , data = data )
363
+ model_definitions_encoded = json .dumps (model_definitions ).encode ()
364
+ schema_location = self ._upload_data (model_definitions_encoded )
336
365
337
366
bundle_metadata = {
338
367
"load_predict_fn_module_path" : load_predict_fn_module_path ,
@@ -350,6 +379,7 @@ def create_model_bundle_from_dirs(
350
379
bundle_metadata = bundle_metadata ,
351
380
requirements = requirements ,
352
381
env_params = env_params ,
382
+ schema_location = schema_location ,
353
383
)
354
384
_add_app_config_to_bundle_create_payload (payload , app_config )
355
385
@@ -367,6 +397,7 @@ def create_model_bundle_from_dirs(
367
397
packaging_type = ModelBundlePackagingType ("zip" ),
368
398
metadata = bundle_metadata ,
369
399
app_config = payload .get ("app_config" ),
400
+ schema_location = schema_location ,
370
401
)
371
402
create_model_bundle_request = CreateModelBundleRequest (** payload ) # type: ignore
372
403
api_instance .create_model_bundle_v1_model_bundles_post (
@@ -390,6 +421,8 @@ def create_model_bundle( # pylint: disable=too-many-statements
390
421
bundle_url : Optional [str ] = None ,
391
422
app_config : Optional [Union [Dict [str , Any ], str ]] = None ,
392
423
globals_copy : Optional [Dict [str , Any ]] = None ,
424
+ request_schema : Optional [Type [BaseModel ]] = None ,
425
+ response_schema : Optional [Type [BaseModel ]] = None ,
393
426
) -> ModelBundle :
394
427
"""
395
428
Uploads and registers a model bundle to Scale Launch.
@@ -466,6 +499,13 @@ def create_model_bundle( # pylint: disable=too-many-statements
466
499
467
500
bundle_url: (Only used in self-hosted mode.) The desired location of bundle.
468
501
Overrides any value given by ``self.bundle_location_fn``
502
+
503
+ request_schema: A pydantic model that represents the request schema for the model
504
+ bundle. This is used to validate the request body for the model bundle's endpoint.
505
+
506
+ response_schema: A pydantic model that represents the request schema for the model
507
+ bundle. This is used to validate the response for the model bundle's endpoint.
508
+ Note: If request_schema is specified, then response_schema must also be specified.
469
509
"""
470
510
# TODO(ivan): remove `disable=too-many-branches` when get rid of `load_*` functions
471
511
# pylint: disable=too-many-branches
@@ -533,29 +573,20 @@ def create_model_bundle( # pylint: disable=too-many-statements
533
573
)
534
574
535
575
serialized_bundle = cloudpickle .dumps (bundle )
576
+ raw_bundle_url = self ._upload_data (data = serialized_bundle )
536
577
537
- if self .self_hosted :
538
- if self .upload_bundle_fn is None :
539
- raise ValueError ("Upload_bundle_fn should be registered" )
540
- if self .bundle_location_fn is None and bundle_url is None :
541
- raise ValueError (
542
- "Need either bundle_location_fn or bundle_url to know where to upload bundles"
543
- )
544
- if bundle_url is None :
545
- bundle_url = self .bundle_location_fn () # type: ignore
546
- self .upload_bundle_fn (serialized_bundle , bundle_url )
547
- raw_bundle_url = bundle_url
548
- else :
549
- # Grab a signed url to make upload to
550
- model_bundle_s3_url = self .connection .post (
551
- {}, MODEL_BUNDLE_SIGNED_URL_PATH
578
+ schema_location = None
579
+ if bool (request_schema ) ^ bool (response_schema ):
580
+ raise ValueError (
581
+ "If request_schema is specified, then response_schema must also be specified."
552
582
)
553
- s3_path = model_bundle_s3_url ["signedUrl" ]
554
- raw_bundle_url = f"s3://{ model_bundle_s3_url ['bucket' ]} /{ model_bundle_s3_url ['key' ]} "
555
-
556
- # Make bundle upload
557
-
558
- requests .put (s3_path , data = serialized_bundle )
583
+ if request_schema is not None and response_schema is not None :
584
+ model_definitions = get_model_definitions (
585
+ request_schema = request_schema ,
586
+ response_schema = response_schema ,
587
+ )
588
+ model_definitions_encoded = json .dumps (model_definitions ).encode ()
589
+ schema_location = self ._upload_data (model_definitions_encoded )
559
590
560
591
payload = dict (
561
592
packaging_type = "cloudpickle" ,
@@ -564,6 +595,7 @@ def create_model_bundle( # pylint: disable=too-many-statements
564
595
bundle_metadata = bundle_metadata ,
565
596
requirements = requirements ,
566
597
env_params = env_params ,
598
+ schema_location = schema_location ,
567
599
)
568
600
569
601
_add_app_config_to_bundle_create_payload (payload , app_config )
@@ -581,7 +613,8 @@ def create_model_bundle( # pylint: disable=too-many-statements
581
613
requirements = requirements ,
582
614
packaging_type = ModelBundlePackagingType ("cloudpickle" ),
583
615
metadata = bundle_metadata ,
584
- app_config = payload .get ("app_config" ),
616
+ app_config = app_config ,
617
+ schema_location = schema_location ,
585
618
)
586
619
create_model_bundle_request = CreateModelBundleRequest (** payload ) # type: ignore
587
620
api_instance .create_model_bundle_v1_model_bundles_post (
0 commit comments