Skip to content

Commit 931275a

Browse files
authored
fix: convert to str for torch.load (#2277)
* fix: convert to str for torch.load * fix: map to cpu * fix: skip the test if < 1.5.0 * revert: remove pull_request
1 parent ed5e1dd commit 931275a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/ignite/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
from collections import namedtuple
5+
from distutils.version import LooseVersion
56

67
import pytest
78
import torch
@@ -244,6 +245,7 @@ def test_smoke__utils():
244245
from ignite._utils import apply_to_tensor, apply_to_type, convert_tensor, to_onehot # noqa: F401
245246

246247

248+
@pytest.mark.skipif(LooseVersion(torch.__version__) < LooseVersion("1.5.0"), reason="Skip if < 1.5.0")
247249
def test_hash_checkpoint(tmp_path):
248250
# download lightweight model
249251
from torchvision.models import squeezenet1_0
@@ -253,7 +255,7 @@ def test_hash_checkpoint(tmp_path):
253255
"https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", f"{tmp_path}/squeezenet1_0.pt",
254256
)
255257
hash_checkpoint_path, sha_hash = hash_checkpoint(f"{tmp_path}/squeezenet1_0.pt", str(tmp_path))
256-
model.load_state_dict(torch.load(hash_checkpoint_path), True)
258+
model.load_state_dict(torch.load(str(hash_checkpoint_path), "cpu"), True)
257259
assert sha_hash[:8] == "b66bff10"
258260
assert hash_checkpoint_path.name == f"squeezenet1_0-{sha_hash[:8]}.pt"
259261

0 commit comments

Comments
 (0)