From 6efa0b0545349104608f7ccf3da1cb5ece5177f2 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 26 Mar 2025 12:02:39 -0700 Subject: [PATCH 1/4] fix: correctly handle method parameter counting in function_extra_arg --- .../handler_service.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py index 84ba47a..3c87653 100644 --- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py +++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py @@ -301,23 +301,23 @@ def validate_and_initialize_user_module(self): ) self.log_func_implementation_found_or_not(load_fn, MODEL_FN) if load_fn is not None: - self.load_extra_arg = self.function_extra_arg(self.load, load_fn) + self.load_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.load, load_fn) self.load = load_fn self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN) if preprocess_fn is not None: - self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn) + self.preprocess_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.preprocess, preprocess_fn) self.preprocess = preprocess_fn self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN) if predict_fn is not None: - self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn) + self.predict_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.predict, predict_fn) self.predict = predict_fn self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN) if postprocess_fn is not None: - self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn) + self.postprocess_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.postprocess, postprocess_fn) self.postprocess = postprocess_fn self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN) if transform_fn is not None: - self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn) + self.transform_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.transform_fn, transform_fn) self.transform_fn = transform_fn else: logger.info( @@ -342,8 +342,15 @@ def function_extra_arg(self, default_func, func): 1. the handle function takes context 2. the handle function does not take context """ - num_default_func_input = len(signature(default_func).parameters) - num_func_input = len(signature(func).parameters) + default_params = signature(default_func).parameters + func_params = signature(func).parameters + + if 'self' in default_params: + num_default_func_input = len(default_params) - 1 + else: + num_default_func_input = len(default_params) + + num_func_input = len(func_params) if num_default_func_input == num_func_input: # function takes context extra_args = [self.context] From 89be78518101e260bfb0e9781ec4976097844e3d Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 26 Mar 2025 12:06:39 -0700 Subject: [PATCH 2/4] format --- .../handler_service.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/sagemaker_huggingface_inference_toolkit/handler_service.py b/src/sagemaker_huggingface_inference_toolkit/handler_service.py index 3c87653..755a8df 100644 --- a/src/sagemaker_huggingface_inference_toolkit/handler_service.py +++ b/src/sagemaker_huggingface_inference_toolkit/handler_service.py @@ -305,7 +305,9 @@ def validate_and_initialize_user_module(self): self.load = load_fn self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN) if preprocess_fn is not None: - self.preprocess_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.preprocess, preprocess_fn) + self.preprocess_extra_arg = self.function_extra_arg( + HuggingFaceHandlerService.preprocess, preprocess_fn + ) self.preprocess = preprocess_fn self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN) if predict_fn is not None: @@ -313,11 +315,15 @@ def validate_and_initialize_user_module(self): self.predict = predict_fn self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN) if postprocess_fn is not None: - self.postprocess_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.postprocess, postprocess_fn) + self.postprocess_extra_arg = self.function_extra_arg( + HuggingFaceHandlerService.postprocess, postprocess_fn + ) self.postprocess = postprocess_fn self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN) if transform_fn is not None: - self.transform_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.transform_fn, transform_fn) + self.transform_extra_arg = self.function_extra_arg( + HuggingFaceHandlerService.transform_fn, transform_fn + ) self.transform_fn = transform_fn else: logger.info( @@ -345,7 +351,7 @@ def function_extra_arg(self, default_func, func): default_params = signature(default_func).parameters func_params = signature(func).parameters - if 'self' in default_params: + if "self" in default_params: num_default_func_input = len(default_params) - 1 else: num_default_func_input = len(default_params) From d1db51fa37a7339ffd8a6f7b39655c4e460d445a Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 28 Mar 2025 10:20:14 -0700 Subject: [PATCH 3/4] add tests --- .../unit/test_handler_service_with_context.py | 20 +++++++++++++++++++ .../test_handler_service_without_context.py | 17 ++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tests/unit/test_handler_service_with_context.py b/tests/unit/test_handler_service_with_context.py index a8b5b71..117a488 100644 --- a/tests/unit/test_handler_service_with_context.py +++ b/tests/unit/test_handler_service_with_context.py @@ -166,3 +166,23 @@ def test_validate_and_initialize_user_module_transform_fn(): inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT) == "output dummy" ) + + +def test_validate_and_initialize_user_module_transform_fn(): + os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" + inference_handler = handler_service.HuggingFaceHandlerService() + model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context") + CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") + + # Similuate 2 threads bypassing check in handle() - calling initialize twice + inference_handler.initialize(CONTEXT) + inference_handler.initialize(CONTEXT) + + CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] + CONTEXT.metrics = MetricsStore(1, MODEL) + assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] + assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py" + assert ( + inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT) + == "output dummy" + ) diff --git a/tests/unit/test_handler_service_without_context.py b/tests/unit/test_handler_service_without_context.py index 37c37e7..5acef1d 100644 --- a/tests/unit/test_handler_service_without_context.py +++ b/tests/unit/test_handler_service_without_context.py @@ -154,3 +154,20 @@ def test_validate_and_initialize_user_module_transform_fn(): assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] assert inference_handler.load({}) == "Loading inference_tranform_fn.py" assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy" + + +def test_validate_and_initialize_user_module_transform_fn(): + os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" + inference_handler = handler_service.HuggingFaceHandlerService() + model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context") + CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4") + + # Similuate 2 threads bypassing check in handle() - calling initialize twice + inference_handler.initialize(CONTEXT) + inference_handler.initialize(CONTEXT) + + CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})] + CONTEXT.metrics = MetricsStore(1, MODEL) + assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0] + assert inference_handler.load({}) == "Loading inference_tranform_fn.py" + assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy" From a204e71e7890062cf3981bb8031419c022811388 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Fri, 28 Mar 2025 10:23:03 -0700 Subject: [PATCH 4/4] fix test name --- tests/unit/test_handler_service_with_context.py | 2 +- tests/unit/test_handler_service_without_context.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_handler_service_with_context.py b/tests/unit/test_handler_service_with_context.py index 117a488..1790e40 100644 --- a/tests/unit/test_handler_service_with_context.py +++ b/tests/unit/test_handler_service_with_context.py @@ -168,7 +168,7 @@ def test_validate_and_initialize_user_module_transform_fn(): ) -def test_validate_and_initialize_user_module_transform_fn(): +def test_validate_and_initialize_user_module_transform_fn_race_condition(): os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" inference_handler = handler_service.HuggingFaceHandlerService() model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context") diff --git a/tests/unit/test_handler_service_without_context.py b/tests/unit/test_handler_service_without_context.py index 5acef1d..40fd4c3 100644 --- a/tests/unit/test_handler_service_without_context.py +++ b/tests/unit/test_handler_service_without_context.py @@ -156,7 +156,7 @@ def test_validate_and_initialize_user_module_transform_fn(): assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy" -def test_validate_and_initialize_user_module_transform_fn(): +def test_validate_and_initialize_user_module_transform_fn_race_condition(): os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py" inference_handler = handler_service.HuggingFaceHandlerService() model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")