20
20
21
21
22
22
class ObjectDetectionEval (nn .Module ):
23
- def forward (
24
- self , img : Tensor , target : Optional [Dict [str , Tensor ]] = None
25
- ) -> Tuple [Tensor , Optional [Dict [str , Tensor ]]]:
23
+ def forward (self , img : Tensor ) -> Tensor :
26
24
if not isinstance (img , Tensor ):
27
25
img = F .pil_to_tensor (img )
28
- return F .convert_image_dtype (img , torch .float ), target
26
+ return F .convert_image_dtype (img , torch .float )
29
27
30
28
31
29
class ImageClassificationEval (nn .Module ):
@@ -95,28 +93,22 @@ def __init__(
95
93
self ._interpolation = interpolation
96
94
self ._interpolation_target = interpolation_target
97
95
98
- def forward (self , img : Tensor , target : Optional [ Tensor ] = None ) -> Tuple [ Tensor , Optional [ Tensor ]] :
96
+ def forward (self , img : Tensor ) -> Tensor :
99
97
if isinstance (self ._size , list ):
100
98
img = F .resize (img , self ._size , interpolation = self ._interpolation )
101
99
if not isinstance (img , Tensor ):
102
100
img = F .pil_to_tensor (img )
103
101
img = F .convert_image_dtype (img , torch .float )
104
102
img = F .normalize (img , mean = self ._mean , std = self ._std )
105
- if target :
106
- if isinstance (self ._size , list ):
107
- target = F .resize (target , self ._size , interpolation = self ._interpolation_target )
108
- if not isinstance (target , Tensor ):
109
- target = F .pil_to_tensor (target )
110
- target = target .squeeze (0 ).to (torch .int64 )
111
- return img , target
103
+ return img
112
104
113
105
114
106
class OpticalFlowEval (nn .Module ):
115
- def forward (
116
- self , img1 : Tensor , img2 : Tensor , flow : Optional [ Tensor ] = None , valid_flow_mask : Optional [ Tensor ] = None
117
- ) -> Tuple [ Tensor , Tensor , Optional [ Tensor ], Optional [ Tensor ]]:
118
-
119
- img1 , img2 , flow , valid_flow_mask = self . _pil_or_numpy_to_tensor ( img1 , img2 , flow , valid_flow_mask )
107
+ def forward (self , img1 : Tensor , img2 : Tensor ) -> Tuple [ Tensor , Tensor ]:
108
+ if not isinstance ( img1 , Tensor ):
109
+ img1 = F . pil_to_tensor ( img1 )
110
+ if not isinstance ( img2 , Tensor ):
111
+ img2 = F . pil_to_tensor ( img2 )
120
112
121
113
img1 = F .convert_image_dtype (img1 , torch .float32 )
122
114
img2 = F .convert_image_dtype (img2 , torch .float32 )
@@ -128,19 +120,4 @@ def forward(
128
120
img1 = img1 .contiguous ()
129
121
img2 = img2 .contiguous ()
130
122
131
- return img1 , img2 , flow , valid_flow_mask
132
-
133
- def _pil_or_numpy_to_tensor (
134
- self , img1 : Tensor , img2 : Tensor , flow : Optional [Tensor ], valid_flow_mask : Optional [Tensor ]
135
- ) -> Tuple [Tensor , Tensor , Optional [Tensor ], Optional [Tensor ]]:
136
- if not isinstance (img1 , Tensor ):
137
- img1 = F .pil_to_tensor (img1 )
138
- if not isinstance (img2 , Tensor ):
139
- img2 = F .pil_to_tensor (img2 )
140
-
141
- if flow is not None and not isinstance (flow , Tensor ):
142
- flow = torch .from_numpy (flow )
143
- if valid_flow_mask is not None and not isinstance (valid_flow_mask , Tensor ):
144
- valid_flow_mask = torch .from_numpy (valid_flow_mask )
145
-
146
- return img1 , img2 , flow , valid_flow_mask
123
+ return img1 , img2
0 commit comments