Improve prediction schema mismatch #249

Merged
7 commits merged on Jul 23, 2019
2 changes: 1 addition & 1 deletion pkg/workloads/cortex/lib/context.py
@@ -280,7 +280,7 @@ def get_request_handler_impl(self, api_name):
module_prefix, api["name"], api["request_handler_impl_key"]
)
except CortexException as e:
e.wrap("api " + api_name, "request_handler")
e.wrap("api " + api_name, "request_handler " + api["request_handler"])
raise

try:
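The wrapped error above now includes the configured request handler path, so a failure to load the handler module names the offending file. Below is a minimal sketch of the wrapping behavior, assuming CortexException.wrap() prepends context segments to the final message; the WrappedError class, the "iris" API name, and the handlers/iris.py path are hypothetical stand-ins, not code from this repo.

```python
# Hypothetical stand-in for CortexException: wrap() prepends context segments
# that are joined into the final error message.
class WrappedError(Exception):
    def __init__(self, *messages):
        self.messages = list(messages)
        super().__init__(": ".join(self.messages))

    def wrap(self, *messages):
        self.messages = list(messages) + self.messages

    def __str__(self):
        return ": ".join(self.messages)


try:
    try:
        raise WrappedError("unable to load module")
    except WrappedError as e:
        # with this change the handler file name is part of the context,
        # instead of just the generic "request_handler" label
        e.wrap("api iris", "request_handler handlers/iris.py")
        raise
except WrappedError as e:
    print(e)  # api iris: request_handler handlers/iris.py: unable to load module
```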
47 changes: 29 additions & 18 deletions pkg/workloads/cortex/onnx_serve/api.py
@@ -84,44 +84,55 @@ def transform_to_numpy(input_pyobj, input_metadata):
target_dtype = ONNX_TO_NP_TYPE[input_metadata.type]
target_shape = input_metadata.shape

for idx, dim in enumerate(target_shape):
if dim is None:
target_shape[idx] = 1
try:
for idx, dim in enumerate(target_shape):
if dim is None:
target_shape[idx] = 1

if type(input_pyobj) is not np.ndarray:
np_arr = np.array(input_pyobj, dtype=target_dtype)
else:
np_arr = input_pyobj
np_arr = np_arr.reshape(target_shape)
return np_arr
if type(input_pyobj) is not np.ndarray:
np_arr = np.array(input_pyobj, dtype=target_dtype)
else:
np_arr = input_pyobj
np_arr = np_arr.reshape(target_shape)
return np_arr
except Exception as e:
raise UserException(str(e)) from e


def convert_to_onnx_input(sample, input_metadata_list):
sess = local_cache["sess"]

input_dict = {}
if len(input_metadata_list) == 1:
input_metadata = input_metadata_list[0]
if util.is_dict(sample):
if sample.get(input_metadata.name) is None:
raise ValueError("sample should be a dict containing key: " + input_metadata.name)
raise UserException('missing key "{}"'.format(input_metadata.name))
input_dict[input_metadata.name] = transform_to_numpy(
sample[input_metadata.name], input_metadata
)
else:
input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata)
try:
input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata)
except CortexException as e:
e.wrap("key {}".format(input_metadata.name))
raise
else:
for input_metadata in input_metadata_list:
if not util.is_dict(input_metadata):
expected_keys = [metadata.name for metadata in input_metadata_list]
raise ValueError(
"sample should be a dict containing keys: " + ", ".join(expected_keys)
raise UserException(
"expected sample to be a dictionary with keys {}".format(
", ".join('"' + key + '"' for key in expected_keys)
)
)

if sample.get(input_metadata.name) is None:
raise ValueError("sample should be a dict containing key: " + input_metadata.name)

input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata)
raise UserException('missing key "{}"'.format(input_metadata.name))
try:
input_dict[input_metadata.name] = transform_to_numpy(sample, input_metadata)
except CortexException as e:
e.wrap("key {}".format(input_metadata.name))
raise
logger.info(input_dict)
return input_dict


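For context on what the new try/except in transform_to_numpy is guarding: the function fills unknown (None) dimensions with 1, casts the payload to the input's expected dtype, and reshapes it to the model's input shape; a mismatch now surfaces as a UserException rather than an unhandled traceback. A minimal sketch of that conversion, using a hypothetical stand-in for onnxruntime's input metadata (the input name, shape, and sample values are illustrative):

```python
import numpy as np


class FakeInputMetadata:
    """Hypothetical stand-in for an entry of sess.get_inputs()."""

    def __init__(self, name, shape, np_dtype):
        self.name = name
        self.shape = shape
        self.np_dtype = np_dtype


def to_numpy(input_pyobj, input_metadata):
    target_shape = list(input_metadata.shape)
    # unknown (None) dimensions are assumed to hold a single sample
    for idx, dim in enumerate(target_shape):
        if dim is None:
            target_shape[idx] = 1

    if not isinstance(input_pyobj, np.ndarray):
        np_arr = np.array(input_pyobj, dtype=input_metadata.np_dtype)
    else:
        np_arr = input_pyobj
    # a shape or dtype mismatch raises here, which the API layer now reports
    # back to the client as a UserException
    return np_arr.reshape(target_shape)


metadata = FakeInputMetadata("input", [None, 4], np.float32)
print(to_numpy([[5.1, 3.5, 1.4, 0.2]], metadata).shape)  # (1, 4)
```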
188 changes: 106 additions & 82 deletions pkg/workloads/cortex/tf_api/api.py
@@ -245,6 +245,7 @@ def parse_response_proto_raw(response_proto):


def run_predict(sample):
ctx = local_cache["ctx"]
request_handler = local_cache.get("request_handler")

logger.info("sample: " + util.pp_str_flat(sample))
@@ -256,7 +257,15 @@
)
logger.info("pre_inference: " + util.pp_str_flat(prepared_sample))

validate_sample(sample)

if util.is_resource_ref(local_cache["api"]["model"]):
for column in local_cache["required_inputs"]:
column_type = ctx.get_inferred_column_type(column["name"])
prepared_sample[column["name"]] = util.upcast(
prepared_sample[column["name"]], column_type
)

transformed_sample = transform_sample(prepared_sample)
logger.info("transformed_sample: " + util.pp_str_flat(transformed_sample))

@@ -280,25 +289,30 @@
return result


def is_valid_sample(sample):
ctx = local_cache["ctx"]

for column in local_cache["required_inputs"]:
if column["name"] not in sample:
return False, "{} is missing".format(column["name"])

sample_val = sample[column["name"]]
column_type = ctx.get_inferred_column_type(column["name"])
is_valid = util.CORTEX_TYPE_TO_VALIDATOR[column_type](sample_val)

if not is_valid:
return (False, "{} should be a {}".format(column["name"], column_type))
def validate_sample(sample):
api = local_cache["api"]
if util.is_resource_ref(api["model"]):
ctx = local_cache["ctx"]
for column in local_cache["required_inputs"]:
if column["name"] not in sample:
raise UserException('missing key "{}"'.format(column["name"]))
sample_val = sample[column["name"]]
column_type = ctx.get_inferred_column_type(column["name"])
is_valid = util.CORTEX_TYPE_TO_VALIDATOR[column_type](sample_val)

return True, None
if not is_valid:
raise UserException(
'key "{}"'.format(column["name"]), "expected type " + column_type
)
else:
signature = extract_signature()
for input_name, metadata in signature.items():
if input_name not in sample:
raise UserException('missing key "{}"'.format(input_name))


def prediction_failed(sample, reason=None):
message = "prediction failed for sample: {}".format(utils.pp_str_flat(sample))
message = "prediction failed for sample: {}".format(util.pp_str_flat(sample))
if reason:
message += " ({})".format(reason)

@@ -337,15 +351,6 @@ def predict(deployment_name, api_name):
)

for i, sample in enumerate(payload["samples"]):
if util.is_resource_ref(api["model"]):
is_valid, reason = is_valid_sample(sample)
if not is_valid:
return prediction_failed(sample, reason)

for column in local_cache["required_inputs"]:
column_type = ctx.get_inferred_column_type(column["name"])
sample[column["name"]] = util.upcast(sample[column["name"]], column_type)

try:
result = run_predict(sample)
except CortexException as e:
@@ -363,20 +368,7 @@
api["name"]
)
)

# Show signature def for external models (since we don't validate input)
schemaStr = ""
signature_def = local_cache["metadata"]["signatureDef"]
if (
not util.is_resource_ref(api["model"])
and signature_def.get("predict") is not None # Just to be safe
and signature_def["predict"].get("inputs") is not None # Just to be safe
):
schemaStr = "\n\nExpected shema:\n" + util.pp_str(
signature_def["predict"]["inputs"]
)

return prediction_failed(sample, str(e) + schemaStr)
return prediction_failed(sample, str(e))

predictions.append(result)

@@ -386,20 +378,39 @@
return jsonify(response)


@app.route("/<app_name>/<api_name>/signature", methods=["GET"])
def get_signature(app_name, api_name):
def extract_signature():
signature_def = local_cache["metadata"]["signatureDef"]

response = {}
if signature_def.get("predict") is None or signature_def["predict"].get("inputs") is None:
return "unable to get signature for model", status.HTTP_404_NOT_FOUND
raise UserException("unable to find signature definition for model")

metadata = {}
for input_name, input_metadata in signature_def["predict"]["inputs"].items():
metadata[input_name] = {
"shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
"type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
}
return metadata


@app.route("/<app_name>/<api_name>/signature", methods=["GET"])
def get_signature(app_name, api_name):
ctx = local_cache["ctx"]
api = local_cache["api"]

try:
metadata = extract_signature()
except CortexException as e:
logger.error(str(e))
logger.exception(
"An error occurred, see `cortex logs -v api {}` for more details.".format(api["name"])
)
return str(e), HTTP_404_NOT_FOUND
except Exception as e:
logger.exception(
"An error occurred, see `cortex logs -v api {}` for more details.".format(api["name"])
)
return str(e), HTTP_404_NOT_FOUND

response = {"signature": metadata}
return jsonify(response)

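For reference, the metadata that extract_signature() builds (and that the new missing-key errors rely on) maps each input name to its shape and dtype name. The sketch below uses a hand-written signature_def in the structure the code reads (predict -> inputs -> tensorShape.dim / dtype); the "image" input and the DTYPE_TO_NAME mapping are illustrative stand-ins, not the repo's DTYPE_TO_TF_TYPE table.

```python
# Hand-written example of a TensorFlow Serving signature definition; real values
# come from the model's metadata response.
signature_def = {
    "predict": {
        "inputs": {
            "image": {
                "tensorShape": {"dim": [{"size": "1"}, {"size": "28"}, {"size": "28"}]},
                "dtype": "DT_FLOAT",
            }
        }
    }
}

# stand-in for DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name
DTYPE_TO_NAME = {"DT_FLOAT": "float32"}

metadata = {}
for input_name, input_metadata in signature_def["predict"]["inputs"].items():
    metadata[input_name] = {
        "shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
        "type": DTYPE_TO_NAME[input_metadata["dtype"]],
    }

print(metadata)  # {'image': {'shape': [1, 28, 28], 'type': 'float32'}}
```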
@@ -435,52 +446,65 @@ def start(args):
local_cache["api"] = api
local_cache["ctx"] = ctx

if api.get("request_handler_impl_key") is not None:
local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"])

if not util.is_resource_ref(api["model"]):
if api.get("request_handler") is not None:
try:
if api.get("request_handler_impl_key") is not None:
local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"])

if not util.is_resource_ref(api["model"]):
if api.get("request_handler") is not None:
package.install_packages(ctx.python_packages, ctx.storage)
if not os.path.isdir(args.model_dir):
ctx.storage.download_and_unzip_external(api["model"], args.model_dir)
else:
package.install_packages(ctx.python_packages, ctx.storage)
if not os.path.isdir(args.model_dir):
ctx.storage.download_and_unzip_external(api["model"], args.model_dir)
else:
package.install_packages(ctx.python_packages, ctx.storage)
model_name = util.get_resource_ref(api["model"])
model = ctx.models[model_name]
estimator = ctx.estimators[model["estimator"]]

local_cache["model"] = model
local_cache["estimator"] = estimator
local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])]
local_cache["target_col_type"] = ctx.get_inferred_column_type(
util.get_resource_ref(model["target_column"])
)
model_name = util.get_resource_ref(api["model"])
model = ctx.models[model_name]
estimator = ctx.estimators[model["estimator"]]

local_cache["model"] = model
local_cache["estimator"] = estimator
local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])]
local_cache["target_col_type"] = ctx.get_inferred_column_type(
util.get_resource_ref(model["target_column"])
)

log_level = "DEBUG"
if ctx.environment is not None and ctx.environment.get("log_level") is not None:
log_level = ctx.environment["log_level"].get("tensorflow", "DEBUG")
tf_lib.set_logging_verbosity(log_level)
log_level = "DEBUG"
if ctx.environment is not None and ctx.environment.get("log_level") is not None:
log_level = ctx.environment["log_level"].get("tensorflow", "DEBUG")
tf_lib.set_logging_verbosity(log_level)

if not os.path.isdir(args.model_dir):
ctx.storage.download_and_unzip(model["key"], args.model_dir)
if not os.path.isdir(args.model_dir):
ctx.storage.download_and_unzip(model["key"], args.model_dir)

for column_name in ctx.extract_column_names([model["input"], model["target_column"]]):
if ctx.is_transformed_column(column_name):
trans_impl, _ = ctx.get_transformer_impl(column_name)
local_cache["trans_impls"][column_name] = trans_impl
transformed_column = ctx.transformed_columns[column_name]
for column_name in ctx.extract_column_names([model["input"], model["target_column"]]):
if ctx.is_transformed_column(column_name):
trans_impl, _ = ctx.get_transformer_impl(column_name)
local_cache["trans_impls"][column_name] = trans_impl
transformed_column = ctx.transformed_columns[column_name]

# cache aggregate values
for resource_name in util.extract_resource_refs(transformed_column["input"]):
if resource_name in ctx.aggregates:
ctx.get_obj(ctx.aggregates[resource_name]["key"])
# cache aggregate values
for resource_name in util.extract_resource_refs(transformed_column["input"]):
if resource_name in ctx.aggregates:
ctx.get_obj(ctx.aggregates[resource_name]["key"])

local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx)
local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx)

if util.is_dict(model["input"]) and model["input"].get("target_vocab") is not None:
local_cache["target_vocab_populated"] = ctx.populate_values(
model["input"]["target_vocab"], None, False
)
if util.is_dict(model["input"]) and model["input"].get("target_vocab") is not None:
local_cache["target_vocab_populated"] = ctx.populate_values(
model["input"]["target_vocab"], None, False
)
except CortexException as e:
e.wrap("error")
logger.error(str(e))
logger.exception(
"An error occurred, see `cortex logs -v api {}` for more details.".format(api["name"])
)
sys.exit(1)
except Exception as e:
logger.exception(
"An error occurred, see `cortex logs -v api {}` for more details.".format(api["name"])
)
sys.exit(1)

try:
validate_model_dir(args.model_dir)