Commit 4dfe7d2

Phil/upload schemas (#68)
* update client
* try again
* fix import
* fixes
* wip
* fixes
* try fix again
* fix
* fixes
* add pydantic schema support
* add test
* fix
* fix flat schemas
* black and isort
1 parent 66c0609 commit 4dfe7d2

6 files changed: +288 -40 lines changed


.pylintrc

Lines changed: 1 addition & 0 deletions
@@ -28,3 +28,4 @@ max-line-length=79
 [MASTER]
 # Ignore anything inside launch/clientlib (since it's documentation)
 ignore=clientlib,api_client
+extension-pkg-whitelist=pydantic

launch/client.py

Lines changed: 72 additions & 39 deletions
@@ -5,13 +5,14 @@
 import shutil
 import tempfile
 from io import StringIO
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
 from zipfile import ZipFile
 
 import cloudpickle
 import requests
 import yaml
 from frozendict import frozendict
+from pydantic import BaseModel
 
 from launch.api_client import ApiClient, Configuration
 from launch.api_client.apis.tags.default_api import DefaultApi
@@ -59,6 +60,7 @@
     ModelEndpoint,
     SyncEndpoint,
 )
+from launch.pydantic_schemas import get_model_definitions
 from launch.request_validation import validate_task_request
 
 DEFAULT_NETWORK_TIMEOUT_SEC = 120
@@ -224,15 +226,39 @@ def register_endpoint_auth_decorator(self, endpoint_auth_decorator_fn):
         """
         self.endpoint_auth_decorator_fn = endpoint_auth_decorator_fn
 
+    def _upload_data(self, data: bytes) -> str:
+        if self.self_hosted:
+            if self.upload_bundle_fn is None:
+                raise ValueError("Upload_bundle_fn should be registered")
+            if self.bundle_location_fn is None:
+                raise ValueError(
+                    "Need either bundle_location_fn to know where to upload bundles"
+                )
+            raw_bundle_url = self.bundle_location_fn()  # type: ignore
+            self.upload_bundle_fn(data, raw_bundle_url)  # type: ignore
+        else:
+            model_bundle_url = self.connection.post(
+                {}, MODEL_BUNDLE_SIGNED_URL_PATH
+            )
+            s3_path = model_bundle_url["signedUrl"]
+            raw_bundle_url = (
+                f"s3://{model_bundle_url['bucket']}/{model_bundle_url['key']}"
+            )
+            requests.put(s3_path, data=data)
+        return raw_bundle_url
+
     def create_model_bundle_from_dirs(
         self,
+        *,
         model_bundle_name: str,
         base_paths: List[str],
         requirements_path: str,
         env_params: Dict[str, str],
         load_predict_fn_module_path: str,
         load_model_fn_module_path: str,
         app_config: Optional[Union[Dict[str, Any], str]] = None,
+        request_schema: Optional[Type[BaseModel]] = None,
+        response_schema: Optional[Type[BaseModel]] = None,
     ) -> ModelBundle:
         """
         Packages up code from one or more local filesystem folders and uploads them as a bundle to Scale Launch.
@@ -302,6 +328,13 @@ def create_model_bundle_from_dirs(
                 the function located at load_predict_fn_module_path.
 
             app_config: Either a Dictionary that represents a YAML file contents or a local path to a YAML file.
+
+            request_schema: A pydantic model that represents the request schema for the model
+                bundle. This is used to validate the request body for the model bundle's endpoint.
+
+            response_schema: A pydantic model that represents the response schema for the model
+                bundle. This is used to validate the response for the model bundle's endpoint.
+                Note: If request_schema is specified, then response_schema must also be specified.
         """
         with open(requirements_path, "r", encoding="utf-8") as req_f:
             requirements = req_f.read().splitlines()
@@ -315,24 +348,20 @@ def create_model_bundle_from_dirs(
         finally:
             shutil.rmtree(tmpdir)
 
-        if self.self_hosted:
-            if self.upload_bundle_fn is None:
-                raise ValueError("Upload_bundle_fn should be registered")
-            if self.bundle_location_fn is None:
-                raise ValueError(
-                    "Need either bundle_location_fn to know where to upload bundles"
-                )
-            raw_bundle_url = self.bundle_location_fn()  # type: ignore
-            self.upload_bundle_fn(data, raw_bundle_url)  # type: ignore
-        else:
-            model_bundle_url = self.connection.post(
-                {}, MODEL_BUNDLE_SIGNED_URL_PATH
+        raw_bundle_url = self._upload_data(data)
+
+        schema_location = None
+        if bool(request_schema) ^ bool(response_schema):
+            raise ValueError(
+                "If request_schema is specified, then response_schema must also be specified."
             )
-            s3_path = model_bundle_url["signedUrl"]
-            raw_bundle_url = (
-                f"s3://{model_bundle_url['bucket']}/{model_bundle_url['key']}"
+        if request_schema is not None and response_schema is not None:
+            model_definitions = get_model_definitions(
+                request_schema=request_schema,
+                response_schema=response_schema,
             )
-            requests.put(s3_path, data=data)
+            model_definitions_encoded = json.dumps(model_definitions).encode()
+            schema_location = self._upload_data(model_definitions_encoded)
 
         bundle_metadata = {
             "load_predict_fn_module_path": load_predict_fn_module_path,
@@ -350,6 +379,7 @@ def create_model_bundle_from_dirs(
             bundle_metadata=bundle_metadata,
             requirements=requirements,
             env_params=env_params,
+            schema_location=schema_location,
         )
         _add_app_config_to_bundle_create_payload(payload, app_config)
 
@@ -367,6 +397,7 @@ def create_model_bundle_from_dirs(
                 packaging_type=ModelBundlePackagingType("zip"),
                 metadata=bundle_metadata,
                 app_config=payload.get("app_config"),
+                schema_location=schema_location,
             )
             create_model_bundle_request = CreateModelBundleRequest(**payload)  # type: ignore
             api_instance.create_model_bundle_v1_model_bundles_post(
@@ -390,6 +421,8 @@ def create_model_bundle(  # pylint: disable=too-many-statements
         bundle_url: Optional[str] = None,
         app_config: Optional[Union[Dict[str, Any], str]] = None,
         globals_copy: Optional[Dict[str, Any]] = None,
+        request_schema: Optional[Type[BaseModel]] = None,
+        response_schema: Optional[Type[BaseModel]] = None,
     ) -> ModelBundle:
         """
         Uploads and registers a model bundle to Scale Launch.
@@ -466,6 +499,13 @@ def create_model_bundle(  # pylint: disable=too-many-statements
 
             bundle_url: (Only used in self-hosted mode.) The desired location of bundle.
                 Overrides any value given by ``self.bundle_location_fn``
+
+            request_schema: A pydantic model that represents the request schema for the model
+                bundle. This is used to validate the request body for the model bundle's endpoint.
+
+            response_schema: A pydantic model that represents the response schema for the model
+                bundle. This is used to validate the response for the model bundle's endpoint.
+                Note: If request_schema is specified, then response_schema must also be specified.
         """
         # TODO(ivan): remove `disable=too-many-branches` when get rid of `load_*` functions
         # pylint: disable=too-many-branches
@@ -533,29 +573,20 @@ def create_model_bundle(  # pylint: disable=too-many-statements
         )
 
         serialized_bundle = cloudpickle.dumps(bundle)
+        raw_bundle_url = self._upload_data(data=serialized_bundle)
 
-        if self.self_hosted:
-            if self.upload_bundle_fn is None:
-                raise ValueError("Upload_bundle_fn should be registered")
-            if self.bundle_location_fn is None and bundle_url is None:
-                raise ValueError(
-                    "Need either bundle_location_fn or bundle_url to know where to upload bundles"
-                )
-            if bundle_url is None:
-                bundle_url = self.bundle_location_fn()  # type: ignore
-            self.upload_bundle_fn(serialized_bundle, bundle_url)
-            raw_bundle_url = bundle_url
-        else:
-            # Grab a signed url to make upload to
-            model_bundle_s3_url = self.connection.post(
-                {}, MODEL_BUNDLE_SIGNED_URL_PATH
+        schema_location = None
+        if bool(request_schema) ^ bool(response_schema):
+            raise ValueError(
+                "If request_schema is specified, then response_schema must also be specified."
             )
-            s3_path = model_bundle_s3_url["signedUrl"]
-            raw_bundle_url = f"s3://{model_bundle_s3_url['bucket']}/{model_bundle_s3_url['key']}"
-
-            # Make bundle upload
-
-            requests.put(s3_path, data=serialized_bundle)
+        if request_schema is not None and response_schema is not None:
+            model_definitions = get_model_definitions(
+                request_schema=request_schema,
+                response_schema=response_schema,
+            )
+            model_definitions_encoded = json.dumps(model_definitions).encode()
+            schema_location = self._upload_data(model_definitions_encoded)
 
         payload = dict(
             packaging_type="cloudpickle",
@@ -564,6 +595,7 @@ def create_model_bundle(  # pylint: disable=too-many-statements
             bundle_metadata=bundle_metadata,
             requirements=requirements,
             env_params=env_params,
+            schema_location=schema_location,
         )
 
         _add_app_config_to_bundle_create_payload(payload, app_config)
@@ -581,7 +613,8 @@ def create_model_bundle(  # pylint: disable=too-many-statements
                 requirements=requirements,
                 packaging_type=ModelBundlePackagingType("cloudpickle"),
                 metadata=bundle_metadata,
-                app_config=payload.get("app_config"),
+                app_config=app_config,
+                schema_location=schema_location,
             )
             create_model_bundle_request = CreateModelBundleRequest(**payload)  # type: ignore
             api_instance.create_model_bundle_v1_model_bundles_post(
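
For context, a minimal usage sketch of the new schema parameters (not part of this commit). It assumes the client class is importable as launch.LaunchClient, and the bundle name, paths, and env_params values below are placeholders rather than values taken from this diff.

from pydantic import BaseModel

from launch import LaunchClient  # assumed import path, not shown in this diff


class MyRequestSchema(BaseModel):
    x: int
    y: str


class MyResponseSchema(BaseModel):
    score: float


client = LaunchClient(api_key="...")  # assumed setup

# request_schema and response_schema must be passed together; passing only one
# now raises a ValueError in the client.
client.create_model_bundle_from_dirs(
    model_bundle_name="my-bundle",                       # placeholder name
    base_paths=["./my_bundle_dir"],                      # placeholder path
    requirements_path="./requirements.txt",              # placeholder path
    env_params={"framework_type": "pytorch"},            # placeholder values
    load_predict_fn_module_path="my_bundle_dir.predict.load_predict_fn",  # placeholder
    load_model_fn_module_path="my_bundle_dir.predict.load_model_fn",      # placeholder
    request_schema=MyRequestSchema,
    response_schema=MyResponseSchema,
)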

launch/pydantic_schemas.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+from enum import Enum
+from typing import Any, Dict, Set, Type, Union
+
+from pydantic import BaseModel
+from pydantic.schema import get_flat_models_from_models, model_process_schema
+
+REF_PREFIX = "#/components/schemas/"
+
+
+def get_model_definitions(
+    request_schema: Type[BaseModel], response_schema: Type[BaseModel]
+) -> Dict[str, Any]:
+    """
+    Gets the model schemas in jsonschema format from a sequence of Pydantic BaseModels.
+    """
+    flat_models = get_flat_models_from_models(
+        [request_schema, response_schema]
+    )
+    model_name_map = {model: model.__name__ for model in flat_models}
+    model_name_map.update(
+        {request_schema: "RequestSchema", response_schema: "ResponseSchema"}
+    )
+    return get_model_definitions_from_flat_models(
+        flat_models=flat_models, model_name_map=model_name_map
+    )
+
+
+def get_model_definitions_from_flat_models(
+    *,
+    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
+    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
+) -> Dict[str, Any]:
+    """
+    Gets the model schemas in jsonschema format from a set of Pydantic BaseModels (or Enums).
+    Inspired by https://github.com/tiangolo/fastapi/blob/99d8470a8e1cf76da8c5274e4e372630efc95736/fastapi/utils.py#L38
+
+    Args:
+        flat_models (Set[Union[Type[BaseModel], Type[Enum]]]): The models.
+        model_name_map (Dict[Union[Type[BaseModel], Type[Enum]], str]): The map from model to name.
+
+    Returns:
+        Dict[str, Any]: OpenAPI-compatible schema of model definitions.
+    """
+    definitions: Dict[str, Dict[str, Any]] = {}
+    for model in flat_models:
+        m_schema, m_definitions, _ = model_process_schema(
+            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
+        )
+        definitions.update(m_definitions)
+        model_name = model_name_map[model]
+        if "description" in m_schema:
+            m_schema["description"] = m_schema["description"].split("\f")[0]
+        definitions[model_name] = m_schema
+    return definitions
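
For reference, a small sketch (also not part of this commit) of what get_model_definitions produces for a simple request/response pair. It assumes pydantic 1.x, where pydantic.schema.get_flat_models_from_models and model_process_schema are available; the model classes are illustrative only.

from pydantic import BaseModel

from launch.pydantic_schemas import get_model_definitions


class MyRequest(BaseModel):
    prompt: str
    temperature: float = 1.0


class MyResponse(BaseModel):
    text: str


definitions = get_model_definitions(
    request_schema=MyRequest, response_schema=MyResponse
)
# The top-level models are renamed, so the result contains "RequestSchema" and
# "ResponseSchema" entries, each an OpenAPI-style object schema with
# "properties" (and "required" where applicable).
print(sorted(definitions))  # ['RequestSchema', 'ResponseSchema']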

poetry.lock

Lines changed: 54 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ pyyaml = ">=5.3.1,<7.0.0"
 typing-extensions = "^4.1.1"
 click = "^8.0.0"
 frozendict = "^2.3.4"
+pydantic = "^1.10.4"
 types-frozendict = "^2.0.9"
 
 [tool.poetry.dev-dependencies]
