
Commit 6f78bac

split uniform sample
1 parent 8f123b4 commit 6f78bac

8 files changed: +254 additions, -251 deletions

configs/recognition/mvit/mvit-base-p244_u32_sthv2-rgb.py

Lines changed: 3 additions & 14 deletions

@@ -32,10 +32,7 @@
 file_client_args = dict(io_backend='disk')
 train_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=32,
-        out_of_bound_opt='repeat_frame'),
+    dict(type='UniformSample', clip_len=32),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
     dict(type='RandomResizedCrop'),
@@ -51,11 +48,7 @@
 ]
 val_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=32,
-        out_of_bound_opt='repeat_frame',
-        test_mode=True),
+    dict(type='UniformSample', clip_len=32, test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
     dict(type='CenterCrop', crop_size=224),
@@ -64,11 +57,7 @@
 ]
 test_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=32,
-        out_of_bound_opt='repeat_frame',
-        test_mode=True),
+    dict(type='UniformSample', clip_len=32, test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 224)),
     dict(type='ThreeCrop', crop_size=224),

configs/recognition/mvit/mvit-large-p244_u40_sthv2-rgb.py

Lines changed: 3 additions & 14 deletions

@@ -34,10 +34,7 @@
 file_client_args = dict(io_backend='disk')
 train_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=40,
-        out_of_bound_opt='repeat_frame'),
+    dict(type='UniformSample', clip_len=40),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
     dict(type='RandomResizedCrop'),
@@ -53,11 +50,7 @@
 ]
 val_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=40,
-        out_of_bound_opt='repeat_frame',
-        test_mode=True),
+    dict(type='UniformSample', clip_len=40, test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
     dict(type='CenterCrop', crop_size=224),
@@ -66,11 +59,7 @@
 ]
 test_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=40,
-        out_of_bound_opt='repeat_frame',
-        test_mode=True),
+    dict(type='UniformSample', clip_len=40, test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 224)),
     dict(type='ThreeCrop', crop_size=224),

configs/recognition/mvit/mvit-small-p244_u16_sthv2-rgb.py

Lines changed: 3 additions & 14 deletions

@@ -15,10 +15,7 @@
 file_client_args = dict(io_backend='disk')
 train_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=16,
-        out_of_bound_opt='repeat_frame'),
+    dict(type='UniformSample', clip_len=16),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
     dict(type='RandomResizedCrop'),
@@ -34,11 +31,7 @@
 ]
 val_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=16,
-        out_of_bound_opt='repeat_frame',
-        test_mode=True),
+    dict(type='UniformSample', clip_len=16, test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 256)),
     dict(type='CenterCrop', crop_size=224),
@@ -47,11 +40,7 @@
 ]
 test_pipeline = [
     dict(type='DecordInit', **file_client_args),
-    dict(
-        type='UniformSampleFrames',
-        clip_len=16,
-        out_of_bound_opt='repeat_frame',
-        test_mode=True),
+    dict(type='UniformSample', clip_len=16, test_mode=True),
     dict(type='DecordDecode'),
     dict(type='Resize', scale=(-1, 224)),
     dict(type='ThreeCrop', crop_size=224),
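
All three MViT configs make the same substitution and differ only in clip_len (32, 40, and 16): the multi-line UniformSampleFrames entry carrying out_of_bound_opt='repeat_frame' collapses into a one-line UniformSample entry. A minimal before/after sketch of just the sampler entry, with values copied from the diffs above:

# Before: repeat-frame handling was an option on UniformSampleFrames.
sampler_old = dict(
    type='UniformSampleFrames',
    clip_len=32,
    out_of_bound_opt='repeat_frame',
    test_mode=True)

# After: repeat-frame handling is UniformSample's only behaviour, so the
# option (and, per loading.py below, the test-time seed) goes away.
sampler_new = dict(type='UniformSample', clip_len=32, test_mode=True)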

mmaction/datasets/transforms/__init__.py

Lines changed: 7 additions & 7 deletions

@@ -10,9 +10,9 @@
                       LoadProposals, OpenCVDecode, OpenCVInit, PIMSDecode,
                       PIMSInit, PyAVDecode, PyAVDecodeMotionVector, PyAVInit,
                       RawFrameDecode, SampleAVAFrames, SampleFrames,
-                      UniformSampleFrames, UntrimmedSampleFrames)
+                      UniformSample, UntrimmedSampleFrames)
 from .pose_loading import (GeneratePoseTarget, LoadKineticsPose,
-                           PaddingWithLoop, PoseDecode)
+                           PaddingWithLoop, PoseDecode, UniformSampleFrames)
 from .processing import (AudioAmplify, CenterCrop, ColorJitter, Flip, Fuse,
                          MelSpectrogram, MultiScaleCrop, PoseCompact,
                          RandomCrop, RandomRescale, RandomResizedCrop, Resize,
@@ -30,9 +30,9 @@
     'AudioAmplify', 'MelSpectrogram', 'AudioDecode', 'FormatAudioShape',
     'LoadAudioFeature', 'AudioFeatureSelector', 'AudioDecodeInit',
     'ImageDecode', 'BuildPseudoClip', 'RandomRescale', 'PIMSDecode',
-    'PyAVDecodeMotionVector', 'UniformSampleFrames', 'PoseDecode',
-    'LoadKineticsPose', 'GeneratePoseTarget', 'PIMSInit', 'FormatGCNInput',
-    'PaddingWithLoop', 'ArrayDecode', 'JointToBone', 'PackActionInputs',
-    'PackLocalizationInputs', 'ImgAug', 'TorchVisionWrapper',
-    'PytorchVideoWrapper', 'PoseCompact'
+    'PyAVDecodeMotionVector', 'UniformSample', 'UniformSampleFrames',
+    'PoseDecode', 'LoadKineticsPose', 'GeneratePoseTarget', 'PIMSInit',
+    'FormatGCNInput', 'PaddingWithLoop', 'ArrayDecode', 'JointToBone',
+    'PackActionInputs', 'PackLocalizationInputs', 'ImgAug',
+    'TorchVisionWrapper', 'PytorchVideoWrapper', 'PoseCompact'
 ]
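
UniformSample is now exported from loading.py, while the UniformSampleFrames name is re-exported from pose_loading; both stay in __all__, so config-driven builds keep working because transforms are resolved by string name. A minimal usage sketch, assuming mmaction's TRANSFORMS registry (the import path is an assumption; loading.py below registers into the same registry):

from mmaction.registry import TRANSFORMS  # assumed import path

# Video pipelines use the new, simplified sampler ...
video_sampler = TRANSFORMS.build(dict(type='UniformSample', clip_len=32))

# ... while pose pipelines keep the old name, now defined in pose_loading.
pose_sampler = TRANSFORMS.build(dict(type='UniformSampleFrames', clip_len=32))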

mmaction/datasets/transforms/loading.py

Lines changed: 15 additions & 115 deletions

@@ -266,14 +266,15 @@ def __repr__(self):
 
 
 @TRANSFORMS.register_module()
-class UniformSampleFrames(BaseTransform):
-    """Uniformly sample frames from the video.
+class UniformSample(BaseTransform):
+    """Uniformly sample frames from the video. Currently used for Something-
+    Something V2 dataset. Modified from
+    https://github.com/facebookresearch/SlowFast/blob/64a
+    bcc90ccfdcbb11cf91d6e525bed60e92a8796/slowfast/datasets/ssv2.py#L159.
 
     To sample an n-frame clip from the video. UniformSampleFrames basically
     divides the video into n segments of equal length and randomly samples one
-    frame from each segment. To make the testing results reproducible, a
-    random seed is set during testing, to make the sampling results
-    deterministic.
+    frame from each segment.
 
     Required keys:
 
@@ -292,113 +293,23 @@ class UniformSampleFrames(BaseTransform):
         num_clips (int): Number of clips to be sampled. Default: 1.
         test_mode (bool): Store True when building test or validation dataset.
             Default: False.
-        seed (int): The random seed used during test time. Default: 255.
-        out_of_bound_opt (str): The way to deal with out of bounds frame
-            indexes. Available options are 'loop', 'repeat_frame'.
-            Default: 'loop'.
     """
 
     def __init__(self,
                  clip_len: int,
                  num_clips: int = 1,
-                 test_mode: bool = False,
-                 seed: int = 255,
-                 out_of_bound_opt: str = 'loop') -> None:
+                 test_mode: bool = False) -> None:
 
         self.clip_len = clip_len
        self.num_clips = num_clips
         self.test_mode = test_mode
-        self.seed = seed
-        self.out_of_bound_opt = out_of_bound_opt
-        assert self.out_of_bound_opt in ['loop', 'repeat_frame']
-
-    def _get_train_clips(self, num_frames: int):
-        """Uniformly sample indices for training clips.
-
-        Args:
-            num_frames (int): The number of frames.
-        """
-
-        assert self.num_clips == 1
-        if num_frames < self.clip_len:
-            start = np.random.randint(0, num_frames)
-            inds = np.arange(start, start + self.clip_len)
-        elif self.clip_len <= num_frames < 2 * self.clip_len:
-            basic = np.arange(self.clip_len)
-            inds = np.random.choice(
-                self.clip_len + 1, num_frames - self.clip_len, replace=False)
-            offset = np.zeros(self.clip_len + 1, dtype=np.int32)
-            offset[inds] = 1
-            offset = np.cumsum(offset)
-            inds = basic + offset[:-1]
-        else:
-            bids = np.array([
-                i * num_frames // self.clip_len
-                for i in range(self.clip_len + 1)
-            ])
-            bsize = np.diff(bids)
-            bst = bids[:self.clip_len]
-            offset = np.random.randint(bsize)
-            inds = bst + offset
-        return inds
-
-    def _get_test_clips(self, num_frames: int):
-        """Uniformly sample indices for testing clips.
 
-        Args:
-            num_frames (int): The number of frames.
-        """
-
-        np.random.seed(self.seed)
-        if num_frames < self.clip_len:
-            # Then we use a simple strategy
-            if num_frames < self.num_clips:
-                start_inds = list(range(self.num_clips))
-            else:
-                start_inds = [
-                    i * num_frames // self.num_clips
-                    for i in range(self.num_clips)
-                ]
-            inds = np.concatenate(
-                [np.arange(i, i + self.clip_len) for i in start_inds])
-        elif self.clip_len <= num_frames < self.clip_len * 2:
-            all_inds = []
-            for i in range(self.num_clips):
-                basic = np.arange(self.clip_len)
-                inds = np.random.choice(
-                    self.clip_len + 1,
-                    num_frames - self.clip_len,
-                    replace=False)
-                offset = np.zeros(self.clip_len + 1, dtype=np.int32)
-                offset[inds] = 1
-                offset = np.cumsum(offset)
-                inds = basic + offset[:-1]
-                all_inds.append(inds)
-            inds = np.concatenate(all_inds)
-        else:
-            bids = np.array([
-                i * num_frames // self.clip_len
-                for i in range(self.clip_len + 1)
-            ])
-            bsize = np.diff(bids)
-            bst = bids[:self.clip_len]
-            all_inds = []
-            for i in range(self.num_clips):
-                offset = np.random.randint(bsize)
-                all_inds.append(bst + offset)
-            inds = np.concatenate(all_inds)
-        return inds
-
-    def _get_repeat_sample_clips(self, num_frames: int) -> np.array:
-        """Repeat sample when video is shorter than clip_len Modified from
-        https://github.com/facebookresearch/SlowFast/blob/64ab
-        cc90ccfdcbb11cf91d6e525bed60e92a8796/slowfast/datasets/ssv2.py#L159.
-
-        When video frames is shorter than target clip len, this strategy would
-        repeat sample frame, rather than loop sample in 'loop' mode.
-        In test mode, this strategy would sample the middle frame of each
-        segment, rather than set a random seed, and therefore only support
-        sample 1 clip.
+    def _get_sample_clips(self, num_frames: int) -> np.array:
+        """When video frames is shorter than target clip len, this strategy
+        would repeat sample frame, rather than loop sample in 'loop' mode. In
+        test mode, this strategy would sample the middle frame of each segment,
+        rather than set a random seed, and therefore only support sample 1
+        clip.
 
         Args:
             num_frames (int): Total number of frame in the video.
@@ -421,17 +332,7 @@ def _get_repeat_sample_clips(self, num_frames: int) -> np.array:
     def transform(self, results: dict):
         num_frames = results['total_frames']
 
-        if self.out_of_bound_opt == 'loop':
-            if self.test_mode:
-                inds = self._get_test_clips(num_frames)
-            else:
-                inds = self._get_train_clips(num_frames)
-            inds = np.mod(inds, num_frames)
-        elif self.out_of_bound_opt == 'repeat_frame':
-            inds = self._get_repeat_sample_clips(num_frames)
-        else:
-            raise ValueError('Illegal out_of_bound option.')
-
+        inds = self._get_sample_clips(num_frames)
         start_index = results['start_index']
         inds = inds + start_index
 
@@ -445,8 +346,7 @@ def __repr__(self):
         repr_str = (f'{self.__class__.__name__}('
                     f'clip_len={self.clip_len}, '
                     f'num_clips={self.num_clips}, '
-                    f'test_mode={self.test_mode}, '
-                    f'seed={self.seed})')
+                    f'test_mode={self.test_mode}')
         return repr_str
 
 