
Custom Inference Code - model_fn() takes more positional argument #126

Open

dil-bhantos opened this issue Jun 17, 2024 · 0 comments

dil-bhantos commented Jun 17, 2024

Hello Everyone,

I have been testing this toolkit, trying out some custom inference code, when I got the following error message.

Issue

1717449617113,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400"

I was following the README and overrode model_fn(model_dir). I had no idea that this function could receive multiple inputs; looking into the original handler implementation, I figured that there might be a second positional argument being passed in.
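
For context, here is a minimal, self-contained sketch of what I think is happening, based on the self.load(*([self.model_dir] + self.load_extra_arg)) call in the traceback below (the names context and args here are placeholders of mine, not the toolkit's actual variables):

def model_fn(model_dir):  # signature as suggested by the README
    return f"model loaded from {model_dir}"

model_dir = "/opt/ml/model"   # placeholder path
context = object()            # stand-in for whatever extra argument the handler appends

# Mirrors self.load(*([self.model_dir] + self.load_extra_arg)) from the traceback:
args = [model_dir] + [context]
model = model_fn(*args)       # TypeError: model_fn() takes 1 positional argument but 2 were given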

My inference code

from transformers import AutoTokenizer, MistralForCausalLM, BitsAndBytesConfig
import torch

def model_fn(model_dir):

    model = MistralForCausalLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer

def predict_fn(data, model_and_tokenizer):

    model, tokenizer = model_and_tokenizer
    sentences = data.pop("inputs", data)
    parameters = data.pop("parameters", None)

    inputs = tokenizer(sentences, return_tensors="pt")

    if parameters is not None:
        outputs = model.generate(**inputs, **parameters)
    else:
        outputs = model.generate(**inputs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
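
(For what it's worth, I sanity-check these two functions locally with something like the following; the path and payload are placeholders, and it assumes the model weights and tokenizer are present in that directory.)

if __name__ == "__main__":
    # Placeholder model directory; assumes the Mistral weights are stored there.
    model_and_tokenizer = model_fn("/opt/ml/model")
    result = predict_fn(
        {"inputs": "Hello, how are you?", "parameters": {"max_new_tokens": 32}},
        model_and_tokenizer,
    )
    print(result)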

Logs

1717449608604,"2024-06-03T21:20:08,521   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   Prediction error"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 243, in handle"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.initialize(context)"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 83, in initialize"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.model =   self.load(*([self.model_dir] + self.load_extra_arg))"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During   handling of the above exception, another exception occurred:"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:55126 ""POST   /invocations HTTP/1.1"" 400 3"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/mms/service.py"",   line 108, in predict"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch,   self.context)"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 267, in handle"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e),   400)"
1717449612107,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   mms.service.PredictionException: model_fn() takes 1 positional argument but 2   were given : 400"
1717449614111,"2024-06-03T21:20:12,046   [INFO ] pool-2-thread-6 ACCESS_LOG - /169.254.178.2:58518 ""GET   /ping HTTP/1.1"" 200 0"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-9000-model com.amazonaws.ml.mms.wlm.WorkerThread - Backend response   time: 1"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   Prediction error"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 243, in handle"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.initialize(context)"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 83, in initialize"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.model =   self.load(*([self.model_dir] + self.load_extra_arg))"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:58518 ""POST   /invocations HTTP/1.1"" 400 2"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During   handling of the above exception, another exception occurred:"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/mms/service.py"",   line 108, in predict"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch,   self.context)"
1717449614111,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 267, in handle"
1717449614111,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e),   400)"
1717449617113,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   mms.service.PredictionException: model_fn() takes 1 positional argument but 2   were given : 400"

Solution

I was able to fix it by adding a new parameter to the function: def model_fn(model_dir, temp=None):.
It was a bit confusing, since the documentation reads as if there were only one argument.
Is it general that model_fn takes two arguments, or did this just happen in my particular case?
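
For reference, this is roughly the version that works for me (the second parameter is unused; I only added it so the handler's extra positional argument has somewhere to go):

from transformers import AutoTokenizer, MistralForCausalLM

def model_fn(model_dir, temp=None):
    # temp absorbs the extra positional argument the handler passes in
    # (apparently some context object; I don't use it).
    model = MistralForCausalLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model, tokenizer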

Thanks!
