Skip to content

Log if handler service is using default or custom functions implementation #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions src/sagemaker_huggingface_inference_toolkit/handler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@

ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true"
PYTHON_PATH_ENV = "PYTHONPATH"
MODEL_FN = "model_fn"
INPUT_FN = "input_fn"
PREDICT_FN = "predict_fn"
OUTPUT_FN = "output_fn"
TRANSFORM_FN = "transform_fn"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -272,35 +277,58 @@ def validate_and_initialize_user_module(self):
"""
user_module_name = self.environment.module_name
if importlib.util.find_spec(user_module_name) is not None:
logger.info("Inference script implementation found at `{}`.".format(user_module_name))
user_module = importlib.import_module(user_module_name)

load_fn = getattr(user_module, "model_fn", None)
preprocess_fn = getattr(user_module, "input_fn", None)
predict_fn = getattr(user_module, "predict_fn", None)
postprocess_fn = getattr(user_module, "output_fn", None)
transform_fn = getattr(user_module, "transform_fn", None)
load_fn = getattr(user_module, MODEL_FN, None)
preprocess_fn = getattr(user_module, INPUT_FN, None)
predict_fn = getattr(user_module, PREDICT_FN, None)
postprocess_fn = getattr(user_module, OUTPUT_FN, None)
transform_fn = getattr(user_module, TRANSFORM_FN, None)

if transform_fn and (preprocess_fn or predict_fn or postprocess_fn):
raise ValueError(
"Cannot use transform_fn implementation in conjunction with "
"input_fn, predict_fn, and/or output_fn implementation"
"Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format(
TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN
)
)

self.log_func_implementation_found_or_not(load_fn, MODEL_FN)
if load_fn is not None:
self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
self.load = load_fn
self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN)
if preprocess_fn is not None:
self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
self.preprocess = preprocess_fn
self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN)
if predict_fn is not None:
self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
self.predict = predict_fn
self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN)
if postprocess_fn is not None:
self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
self.postprocess = postprocess_fn
self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN)
if transform_fn is not None:
self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
self.transform_fn = transform_fn
else:
logger.info(
"No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format(
user_module_name
)
)

@staticmethod
def log_func_implementation_found_or_not(func, func_name):
if func is not None:
logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name))
else:
logger.info(
"No `{}` implementation was found. The default one from the handler service will be used.".format(
func_name
)
)

def function_extra_arg(self, default_func, func):
"""Helper to call the handler function which covers 2 cases:
Expand Down