Skip to content

Commit c4c0ef9

Browse files
authored
make weights deepcopyable (#6883)
* make weights deepcopyable * add test * test enum member instead of whole enum
1 parent d95fbaf commit c4c0ef9

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/test_extended_models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import os
23

34
import pytest
@@ -59,6 +60,25 @@ def test_get_model_weights(name, weight):
5960
assert models.get_model_weights(name) == weight
6061

6162

63+
@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
64+
@pytest.mark.parametrize(
65+
"name",
66+
[
67+
"resnet50",
68+
"retinanet_resnet50_fpn_v2",
69+
"raft_large",
70+
"quantized_resnet50",
71+
"lraspp_mobilenet_v3_large",
72+
"mvit_v1_b",
73+
],
74+
)
75+
def test_weights_copyable(copy_fn, name):
76+
model_weights = models.get_model_weights(name)
77+
for weights in list(model_weights):
78+
copied_weights = copy_fn(weights)
79+
assert copied_weights is weights
80+
81+
6282
@pytest.mark.parametrize(
6383
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
6484
)

torchvision/models/_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def __getattr__(self, name):
7575
return object.__getattribute__(self.value, name)
7676
return super().__getattr__(name)
7777

78+
def __deepcopy__(self, memodict=None):
79+
return self
80+
7881

7982
def get_weight(name: str) -> WeightsEnum:
8083
"""

0 commit comments

Comments
 (0)