Skip to content

Commit b31435a

Browse files
committed
add dependency to fix unit tests
1 parent e18c133 commit b31435a

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
 
 # Hugging Face specific dependencies
 extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]
+extras["diffusers"] = ["diffusers==0.23.1"]
 
 # framework specific dependencies
 extras["torch"] = ["torch>=1.8.0", "torchaudio"]
@@ -87,8 +88,7 @@
     "flake8>=3.8.3",
 ]
 
-extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"]
-
+extras["dev"] = extras["transformers"] + extras["mms"] + extras["torch"] + extras["tensorflow"] + extras["diffusers"]
 
 setup(
     name="sagemaker-huggingface-inference-toolkit",
     version=VERSION,

src/sagemaker_huggingface_inference_toolkit/transformers_utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,14 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
         kwargs["tokenizer"] = model_dir
 
     if TRUST_REMOTE_CODE and os.environ.get("HF_MODEL_ID", None) is not None and device == 0:
-        torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
         tokenizer = AutoTokenizer.from_pretrained(os.environ["HF_MODEL_ID"])
 
         hf_pipeline = pipeline(
             task=task,
             model=os.environ["HF_MODEL_ID"],
             tokenizer=tokenizer,
             trust_remote_code=TRUST_REMOTE_CODE,
-            model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
+            model_kwargs={"device_map": "auto", "torch_dtype": "auto"},
         )
     elif is_diffusers_available() and task == "text-to-image":
         hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)

0 commit comments

Comments
 (0)