diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py index 84ba47a..755a8df 100644 --- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py +++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py @@ -301,23 +301,29 @@ def validate_and_initialize_user_module(self): ) self.log_func_implementation_found_or_not(load_fn, MODEL_FN) if load_fn is not None: - self.load_extra_arg = self.function_extra_arg(self.load, load_fn) + self.load_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.load, load_fn) self.load = load_fn self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN) if preprocess_fn is not None: - self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn) + self.preprocess_extra_arg = self.function_extra_arg( + HuggingFaceHandlerService.preprocess, preprocess_fn + ) self.preprocess = preprocess_fn self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN) if predict_fn is not None: - self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn) + self.predict_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.predict, predict_fn) self.predict = predict_fn self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN) if postprocess_fn is not None: - self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn) + self.postprocess_extra_arg = self.function_extra_arg( + HuggingFaceHandlerService.postprocess, postprocess_fn + ) self.postprocess = postprocess_fn self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN) if transform_fn is not None: - self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn) + self.transform_extra_arg = self.function_extra_arg( + HuggingFaceHandlerService.transform_fn, transform_fn + ) self.transform_fn = transform_fn else: logger.info( @@ -342,8 +348,15 @@ def 
function_extra_arg(self, default_func, func): 1. the handle function takes context 2. the handle function does not take context """ - num_default_func_input = len(signature(default_func).parameters) - num_func_input = len(signature(func).parameters) + default_params = signature(default_func).parameters + func_params = signature(func).parameters + + if "self" in default_params: + num_default_func_input = len(default_params) - 1 + else: + num_default_func_input = len(default_params) + + num_func_input = len(func_params) if num_default_func_input == num_func_input: # function takes context extra_args = [self.context] diff --git a/tests/unit/test_handler_service_with_context.py b/tests/unit/test_handler_service_with_context.py index a8b5b71..1790e40 100644 --- a/tests/unit/test_handler_service_with_context.py +++ b/tests/unit/test_handler_service_with_context.py @@ -166,3 +166,23 @@ def test_validate_and_initialize_user_module_transform_fn(): inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT) == "output dummy" ) + + +def test_validate_and_initialize_user_module_transform_fn_race_condition(): + os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" + inference_handler = handler_service.HuggingFaceHandlerService() + model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context") + CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") + + # Simulate 2 threads bypassing check in handle() - calling initialize twice + inference_handler.initialize(CONTEXT) + inference_handler.initialize(CONTEXT) + + CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] + CONTEXT.metrics = MetricsStore(1, MODEL) + assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] + assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py" + assert ( + inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT) 
+ == "output dummy" + ) diff --git a/tests/unit/test_handler_service_without_context.py b/tests/unit/test_handler_service_without_context.py index 37c37e7..40fd4c3 100644 --- a/tests/unit/test_handler_service_without_context.py +++ b/tests/unit/test_handler_service_without_context.py @@ -154,3 +154,20 @@ def test_validate_and_initialize_user_module_transform_fn(): assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] assert inference_handler.load({}) == "Loading inference_tranform_fn.py" assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy" + + +def test_validate_and_initialize_user_module_transform_fn_race_condition():
+ os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" + inference_handler = handler_service.HuggingFaceHandlerService() + model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context") + CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") + + # Simulate 2 threads bypassing check in handle() - calling initialize twice + inference_handler.initialize(CONTEXT) + inference_handler.initialize(CONTEXT) + + CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] + CONTEXT.metrics = MetricsStore(1, MODEL) + assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] + assert inference_handler.load({}) == "Loading inference_tranform_fn.py" + assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"