Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pkg/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,4 @@ var (
TelemetryURL = "https://telemetry.cortexlabs.dev"

MaxClassesPerRequest = 75 // cloudwatch.GeMetricData can get up to 100 metrics per request, avoid multiple requests and have room for other stats

DefaultTFServingSignatureKey = "predict"
)
9 changes: 1 addition & 8 deletions pkg/operator/api/userconfig/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"strings"

"github.com/aws/aws-sdk-go/service/s3"
"github.com/cortexlabs/cortex/pkg/consts"
"github.com/cortexlabs/cortex/pkg/lib/aws"
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
"github.com/cortexlabs/cortex/pkg/lib/errors"
Expand Down Expand Up @@ -115,7 +114,7 @@ var apiValidation = &cr.StructValidation{
{
StructField: "SignatureKey",
StringValidation: &cr.StringValidation{
Default: consts.DefaultTFServingSignatureKey,
Required: true,
},
},
},
Expand Down Expand Up @@ -277,12 +276,6 @@ func (api *API) Validate(projectFileMap map[string][]byte) error {
}
}

if api.ModelFormat == TensorFlowModelFormat && api.TFServing == nil {
api.TFServing = &TFServingOptions{
SignatureKey: consts.DefaultTFServingSignatureKey,
}
}

if api.ModelFormat != TensorFlowModelFormat && api.TFServing != nil {
return errors.Wrap(ErrorTFServingOptionsForTFOnly(api.ModelFormat), Identify(api))
}
Expand Down
79 changes: 50 additions & 29 deletions pkg/workloads/cortex/tf_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
"ctx": None,
"stub": None,
"api": None,
"signature_key": None,
"signature": None,
"metadata": None,
"request_handler": None,
"class_set": set(),
Expand Down Expand Up @@ -116,7 +118,7 @@ def after_request(response):

def create_prediction_request(sample):
signature_def = local_cache["metadata"]["signatureDef"]
signature_key = local_cache["api"]["tf_serving"]["signature_key"]
signature_key = local_cache["signature_key"]
prediction_request = predict_pb2.PredictRequest()
prediction_request.model_spec.name = "model"
prediction_request.model_spec.signature_name = signature_key
Expand Down Expand Up @@ -207,9 +209,7 @@ def run_predict(sample, debug=False):


def validate_sample(sample):
signature = extract_signature(
local_cache["metadata"]["signatureDef"], local_cache["api"]["tf_serving"]["signature_key"]
)
signature = local_cache["signature"]
for input_name, _ in signature.items():
if input_name not in sample:
raise UserException('missing key "{}"'.format(input_name))
Expand Down Expand Up @@ -252,38 +252,57 @@ def predict(deployment_name, api_name):


def extract_signature(signature_def, signature_key):
if (
signature_def.get(signature_key) is None
or signature_def[signature_key].get("inputs") is None
):
raise UserException(
'unable to find "' + signature_key + "\" in model's signature definition"
)
logger.info("signature defs found in model: {}".format(signature_def))

available_keys = list(signature_def.keys())
if len(available_keys) == 0:
raise UserException("unable to find signature defs in model")

if signature_key is None:
if len(available_keys) == 1:
logger.info(
"signature_key was not configured by user, using signature key '{}' found in signature def map".format(
available_keys[0]
)
)
signature_key = available_keys[0]
else:
raise UserException(
"signature_key was not configured by user, please specify one the following keys '{}' found in signature def map".format(
"', '".join(available_keys)
)
)
else:
if signature_def.get(signature_key) is None:
raise UserException(
"signature_key '{}' was not found, please specify one the following keys '{}' found in signature def map".format(
signature_key, "', '".join(available_keys)
)
)

signature_def_val = signature_def.get(signature_key)

if signature_def_val.get("inputs") is None:
raise UserException("unable to find 'inputs' in signature def '{}'".format(signature_key))

metadata = {}
for input_name, input_metadata in signature_def[signature_key]["inputs"].items():
for input_name, input_metadata in signature_def_val["inputs"].items():
metadata[input_name] = {
"shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
"type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
}
return metadata
return signature_key, metadata


@app.route("/<app_name>/<api_name>/signature", methods=["GET"])
def get_signature(app_name, api_name):
ctx = local_cache["ctx"]
api = local_cache["api"]

try:
metadata = extract_signature(
local_cache["metadata"]["signatureDef"],
local_cache["api"]["tf_serving"]["signature_key"],
)
signature = local_cache["signature"]
except Exception as e:
logger.exception("failed to get signature")
return jsonify(error=str(e)), 404

response = {"signature": metadata}
response = {"signature": signature}
return jsonify(response)


Expand Down Expand Up @@ -385,14 +404,16 @@ def start(args):
sys.exit(1)

time.sleep(5)
logger.info(
"model_signature: {}".format(
extract_signature(
local_cache["metadata"]["signatureDef"],
local_cache["api"]["tf_serving"]["signature_key"],
)
)
)

signature_key = None
if api.get("tf_serving") is not None and api["tf_serving"].get("signature_key") is not None:
signature_key = api["tf_serving"]["signature_key"]

key, metadata = extract_signature(local_cache["metadata"]["signatureDef"], signature_key)

local_cache["signature_key"] = key
local_cache["signature"] = metadata
logger.info("model_signature: {}".format(local_cache["signature"]))
serve(app, listen="*:{}".format(args.port))


Expand Down