From a71781966b5f475cce30c0c8e3c2de8488cd520d Mon Sep 17 00:00:00 2001 From: Alexandre Duverger Date: Thu, 14 Mar 2024 16:18:50 -0400 Subject: [PATCH] Log if handler service is using default or custom functions implementations --- .../handler_service.py | 44 +++++++++++++++---- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py index a924a45..e4ef6c6 100644 --- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py +++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py @@ -34,6 +34,11 @@ ENABLE_MULTI_MODEL = os.getenv("SAGEMAKER_MULTI_MODEL", "false") == "true" PYTHON_PATH_ENV = "PYTHONPATH" +MODEL_FN = "model_fn" +INPUT_FN = "input_fn" +PREDICT_FN = "predict_fn" +OUTPUT_FN = "output_fn" +TRANSFORM_FN = "transform_fn" logger = logging.getLogger(__name__) @@ -272,35 +277,58 @@ def validate_and_initialize_user_module(self): """ user_module_name = self.environment.module_name if importlib.util.find_spec(user_module_name) is not None: + logger.info("Inference script implementation found at `{}`.".format(user_module_name)) user_module = importlib.import_module(user_module_name) - load_fn = getattr(user_module, "model_fn", None) - preprocess_fn = getattr(user_module, "input_fn", None) - predict_fn = getattr(user_module, "predict_fn", None) - postprocess_fn = getattr(user_module, "output_fn", None) - transform_fn = getattr(user_module, "transform_fn", None) + load_fn = getattr(user_module, MODEL_FN, None) + preprocess_fn = getattr(user_module, INPUT_FN, None) + predict_fn = getattr(user_module, PREDICT_FN, None) + postprocess_fn = getattr(user_module, OUTPUT_FN, None) + transform_fn = getattr(user_module, TRANSFORM_FN, None) if transform_fn and (preprocess_fn or predict_fn or postprocess_fn): raise ValueError( - "Cannot use transform_fn implementation in conjunction with " - "input_fn, predict_fn, and/or output_fn implementation" + "Cannot use {} implementation in conjunction with {}, {}, and/or {} implementation".format( + TRANSFORM_FN, INPUT_FN, PREDICT_FN, OUTPUT_FN + ) ) - + self.log_func_implementation_found_or_not(load_fn, MODEL_FN) if load_fn is not None: self.load_extra_arg = self.function_extra_arg(self.load, load_fn) self.load = load_fn + self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN) if preprocess_fn is not None: self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn) self.preprocess = preprocess_fn + self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN) if predict_fn is not None: self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn) self.predict = predict_fn + self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN) if postprocess_fn is not None: self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn) self.postprocess = postprocess_fn + self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN) if transform_fn is not None: self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn) self.transform_fn = transform_fn + else: + logger.info( + "No inference script implementation was found at `{}`. Default implementation of all functions will be used.".format( + user_module_name + ) + ) + + @staticmethod + def log_func_implementation_found_or_not(func, func_name): + if func is not None: + logger.info("`{}` implementation found. It will be used in place of the default one.".format(func_name)) + else: + logger.info( + "No `{}` implementation was found. The default one from the handler service will be used.".format( + func_name + ) + ) def function_extra_arg(self, default_func, func): """Helper to call the handler function which covers 2 cases: