Skip to content

fix: correctly handle method parameter counting in function_extra_arg #136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/sagemaker_huggingface_inference_toolkit/handler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,23 +301,29 @@ def validate_and_initialize_user_module(self):
)
self.log_func_implementation_found_or_not(load_fn, MODEL_FN)
if load_fn is not None:
self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
self.load_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.load, load_fn)
self.load = load_fn
self.log_func_implementation_found_or_not(preprocess_fn, INPUT_FN)
if preprocess_fn is not None:
self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
self.preprocess_extra_arg = self.function_extra_arg(
HuggingFaceHandlerService.preprocess, preprocess_fn
)
self.preprocess = preprocess_fn
self.log_func_implementation_found_or_not(predict_fn, PREDICT_FN)
if predict_fn is not None:
self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
self.predict_extra_arg = self.function_extra_arg(HuggingFaceHandlerService.predict, predict_fn)
self.predict = predict_fn
self.log_func_implementation_found_or_not(postprocess_fn, OUTPUT_FN)
if postprocess_fn is not None:
self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
self.postprocess_extra_arg = self.function_extra_arg(
HuggingFaceHandlerService.postprocess, postprocess_fn
)
self.postprocess = postprocess_fn
self.log_func_implementation_found_or_not(transform_fn, TRANSFORM_FN)
if transform_fn is not None:
self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
self.transform_extra_arg = self.function_extra_arg(
HuggingFaceHandlerService.transform_fn, transform_fn
)
self.transform_fn = transform_fn
else:
logger.info(
Expand All @@ -342,8 +348,15 @@ def function_extra_arg(self, default_func, func):
1. the handle function takes context
2. the handle function does not take context
"""
num_default_func_input = len(signature(default_func).parameters)
num_func_input = len(signature(func).parameters)
default_params = signature(default_func).parameters
func_params = signature(func).parameters

if "self" in default_params:
num_default_func_input = len(default_params) - 1
else:
num_default_func_input = len(default_params)

num_func_input = len(func_params)
if num_default_func_input == num_func_input:
# function takes context
extra_args = [self.context]
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_handler_service_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,23 @@ def test_validate_and_initialize_user_module_transform_fn():
inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
== "output dummy"
)


def test_validate_and_initialize_user_module_transform_fn_race_condition():
os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
inference_handler = handler_service.HuggingFaceHandlerService()
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

# Simulate 2 threads bypassing check in handle() - calling initialize twice
inference_handler.initialize(CONTEXT)
inference_handler.initialize(CONTEXT)
Comment on lines +178 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify why there are two threads calling initialize twice in a single python process?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, it is in the handle() method -

specifically this block:

 if not self.initialized:
                if self.attempted_init:
                    logger.warn(
                        "Model is not initialized, will try to load model again.\n"
                        "Please consider increase wait time for model loading.\n"
                    )
                self.initialize(context)

The test just assumes the fail condition already occurred

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From some personal testing I was able to see that the model gets loaded and later on fails when attempting to load again:

2025-03-25T18:41:48,682 [INFO ] W-9000-model-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Model model loaded io_fd=7a3c6bfffe7cf36a-0000007c-00000000-9116165d8c9c5504-f2c269a2

...

2025-03-25T18:42:34,468 [INFO ] W-model-1-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After testing with fix, did not see such error


CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)
assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
assert (
inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
== "output dummy"
)
17 changes: 17 additions & 0 deletions tests/unit/test_handler_service_without_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,20 @@ def test_validate_and_initialize_user_module_transform_fn():
assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
assert inference_handler.transform_fn("model", "dummy", "application/json", "application/json") == "output dummy"


def test_validate_and_initialize_user_module_transform_fn_race_condition():
    """Regression test: calling initialize() twice on one handler (as two racing
    threads slipping past the initialized-check in handle() would) must not
    break dispatch of a user transform_fn that takes no context argument."""
    os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
    inference_handler = handler_service.HuggingFaceHandlerService()
    model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
    CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

    # Simulate two threads bypassing the check in handle() - both call
    # initialize() on the same instance.
    inference_handler.initialize(CONTEXT)
    inference_handler.initialize(CONTEXT)

    CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
    CONTEXT.metrics = MetricsStore(1, MODEL)

    # After the duplicate initialization, the user module must still be
    # wired up and callable exactly once (no double argument injection).
    response = inference_handler.handle([{"body": b"dummy"}], CONTEXT)
    assert "output" in response[0]
    assert inference_handler.load({}) == "Loading inference_tranform_fn.py"
    transformed = inference_handler.transform_fn("model", "dummy", "application/json", "application/json")
    assert transformed == "output dummy"