Custom Inference Code - model_fn() takes more positional argument #126

Open
@dil-bhantos

Description

Hello Everyone,

I have been testing this toolkit, trying to write some custom inference code, when I hit this error message.

Issue

1717449617113,"2024-06-03T21:20:13,924 [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - mms.service.PredictionException: model_fn() takes 1 positional argument but 2 were given : 400"

I was following the README and overrode model_fn(model_dir). I had no idea that this function could receive multiple arguments; looking into the original handler implementation, I figured that an extra positional argument might be passed to it.
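The traceback in the logs shows the toolkit building the call as self.load(*([self.model_dir] + self.load_extra_arg)). A minimal sketch (the load helper and the "<context>" value below are illustrative, not toolkit code) of why a one-argument model_fn then fails:

```python
def one_arg_model_fn(model_dir):
    # Signature as documented in the README: a single positional argument.
    return "model"

def load(model_fn, model_dir, load_extra_arg):
    # Mirrors the argument construction seen in handler_service.py:
    #   self.load(*([self.model_dir] + self.load_extra_arg))
    return model_fn(*([model_dir] + load_extra_arg))

# When load_extra_arg is empty, the one-argument override works fine.
load(one_arg_model_fn, "/opt/ml/model", [])

# When it carries an extra value (e.g. a worker context), the call
# becomes model_fn(model_dir, context) and raises the TypeError from the logs.
try:
    load(one_arg_model_fn, "/opt/ml/model", ["<context>"])
except TypeError as exc:
    print(exc)  # ... takes 1 positional argument but 2 were given
```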

My inference code

from transformers import AutoTokenizer, MistralForCausalLM, BitsAndBytesConfig
import torch

def model_fn(model_dir):

    model = MistralForCausalLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer

def predict_fn(data, model_and_tokenizer):

    model, tokenizer = model_and_tokenizer
    sentences = data.pop("inputs", data)
    parameters = data.pop("parameters", None)

    inputs = tokenizer(sentences, return_tensors="pt")

    if parameters is not None:
        outputs = model.generate(**inputs, **parameters)
    else:
        outputs = model.generate(**inputs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
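Before deploying, overrides like this can be smoke-tested locally by swapping in stand-ins for the model and tokenizer. FakeModel and FakeTokenizer below are hypothetical stubs (not toolkit or transformers APIs), used only to exercise both branches of the predict_fn above without downloading weights:

```python
class FakeTokenizer:
    """Hypothetical stub mimicking the tokenizer's call/decode surface."""
    def __call__(self, text, return_tensors=None):
        return {"input_ids": [[1, 2, 3]]}

    def decode(self, ids, skip_special_tokens=True):
        return "decoded text"

class FakeModel:
    """Hypothetical stub whose generate() just echoes its input ids."""
    def generate(self, input_ids=None, **params):
        return input_ids

# predict_fn copied verbatim from the snippet above
def predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer
    sentences = data.pop("inputs", data)
    parameters = data.pop("parameters", None)
    inputs = tokenizer(sentences, return_tensors="pt")
    if parameters is not None:
        outputs = model.generate(**inputs, **parameters)
    else:
        outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Exercises both branches: with and without "parameters".
print(predict_fn({"inputs": "hello"}, (FakeModel(), FakeTokenizer())))
print(predict_fn({"inputs": "hello", "parameters": {"max_new_tokens": 8}},
                 (FakeModel(), FakeTokenizer())))
```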

Logs

1717449608604,"2024-06-03T21:20:08,521   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   Prediction error"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 243, in handle"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.initialize(context)"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 83, in initialize"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.model =   self.load(*([self.model_dir] + self.load_extra_arg))"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During   handling of the above exception, another exception occurred:"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:55126 ""POST   /invocations HTTP/1.1"" 400 3"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/mms/service.py"",   line 108, in predict"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch,   self.context)"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 267, in handle"
1717449608604,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e),   400)"
1717449612107,"2024-06-03T21:20:08,522   [INFO ] W-model-4-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   mms.service.PredictionException: model_fn() takes 1 positional argument but 2   were given : 400"
1717449614111,"2024-06-03T21:20:12,046   [INFO ] pool-2-thread-6 ACCESS_LOG - /169.254.178.2:58518 ""GET   /ping HTTP/1.1"" 200 0"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-9000-model com.amazonaws.ml.mms.wlm.WorkerThread - Backend response   time: 1"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   Prediction error"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 243, in handle"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.initialize(context)"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 83, in initialize"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     self.model =   self.load(*([self.model_dir] + self.load_extra_arg))"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-9000-model ACCESS_LOG - /169.254.178.2:58518 ""POST   /invocations HTTP/1.1"" 400 2"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   TypeError: model_fn() takes 1 positional argument but 2 were given"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - During   handling of the above exception, another exception occurred:"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - "
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - Traceback   (most recent call last):"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/mms/service.py"",   line 108, in predict"
1717449614111,"2024-06-03T21:20:13,923   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     ret = self._entry_point(input_batch,   self.context)"
1717449614111,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   File   ""/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py"",   line 267, in handle"
1717449614111,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -     raise PredictionException(str(e),   400)"
1717449617113,"2024-06-03T21:20:13,924   [INFO ] W-model-2-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle -   mms.service.PredictionException: model_fn() takes 1 positional argument but 2   were given : 400"

Solution

I was able to fix it by adding a second parameter with a default value: def model_fn(model_dir, temp=None):.
It was a bit confusing, as the documentation reads as if there were only one argument.
Is it generally the case that model_fn takes two arguments, or did it just happen in my particular case?
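Giving the second parameter a default value keeps both call shapes working: the documented model_fn(model_dir) and the two-argument call seen in the traceback. In the sketch below the real body (the from_pretrained calls from the snippet above) is stubbed out so the call shapes can be checked without model weights:

```python
def model_fn(model_dir, context=None):
    # context defaults to None, so this accepts both the documented
    # one-argument call and the toolkit's two-argument call.
    # Real loading (MistralForCausalLM.from_pretrained, AutoTokenizer
    # .from_pretrained) would go here; stubbed for illustration.
    return ("model", "tokenizer")

# Both invocation styles now succeed:
print(model_fn("/opt/ml/model"))
print(model_fn("/opt/ml/model", "worker-context"))
```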

Thanks!
