Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions references/video_classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ We assume the training and validation AVI videos are stored at `/data/kinectics4

Run the training on a single node with 8 GPUs:
```bash
torchrun --nproc_per_node=8 train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=16 --cache-dataset --sync-bn --amp
torchrun --nproc_per_node=8 train.py --data-path=/data/kinectics400 --kinetics-version="400" --batch-size=16 --cache-dataset --sync-bn --amp
```

**Note:** all our models were trained on 8 nodes with 8 V100 GPUs each for a total of 64 GPUs. Expected training time for 64 GPUs is 24 hours, depending on the storage solution.
Expand All @@ -30,5 +30,13 @@ torchrun --nproc_per_node=8 train.py --data-path=/data/kinectics400 --train-dir=


```bash
python train.py --data-path=/data/kinectics400 --train-dir=train --val-dir=val --batch-size=8 --cache-dataset
python train.py --data-path=/data/kinectics400 --kinetics-version="400" --batch-size=8 --cache-dataset
```


### Additional Kinetics versions

Since the original release, additional versions of Kinetics dataset became available (Kinetics 600).
Our training scripts support these versions of dataset as well by setting the `--kinetics-version` parameter to `"600"`.

**Note:** training on Kinetics 600 requires a different set of hyperparameters for optimal performance. We do not provide Kinetics 600 pretrained models.
4 changes: 1 addition & 3 deletions references/video_classification/presets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torchvision.transforms import transforms
from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
from transforms import ConvertBCHWtoCBHW


class VideoClassificationPresetTrain:
Expand All @@ -14,7 +14,6 @@ def __init__(
hflip_prob=0.5,
):
trans = [
ConvertBHWCtoBCHW(),
transforms.ConvertImageDtype(torch.float32),
transforms.Resize(resize_size),
]
Expand All @@ -31,7 +30,6 @@ class VideoClassificationPresetEval:
def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
self.transforms = transforms.Compose(
[
ConvertBHWCtoBCHW(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can we simply remove this? It seems something is off here. The old Kinetics400 states

- video (Tensor[T, H, W, C]): the `T` video frames

whereas the new Kinetics states

- video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor

Was this intentional or did we botch this during the introduction of the new class? In any way, we need to adapt our deprecation warnings accordingly:

.. warning::
This class was deprecated in ``0.12`` and will be removed in ``0.14``. Please use
``Kinetics(..., num_classes='400')`` instead.

warnings.warn(
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
"Please use Kinetics(..., num_classes='400') instead."
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intentional -- the first iteration was convenient because that's the raw output from FFMPEG/pyav, but tv ops like transforms and conv operators work on Tensor[T, C, H, W] (as you've seen you'd have to have a transform to deal with this). The change makes sense, and I'll update the deprication warning, how does that sound?

transforms.ConvertImageDtype(torch.float32),
transforms.Resize(resize_size),
transforms.Normalize(mean=mean, std=std),
Expand Down
21 changes: 13 additions & 8 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def main(args):

# Data loading code
print("Loading data")
traindir = os.path.join(args.data_path, args.train_dir)
valdir = os.path.join(args.data_path, args.val_dir)
traindir = os.path.join(args.data_path, "train")
valdir = os.path.join(args.data_path, "val")

print("Loading training data")
st = time.time()
Expand All @@ -145,9 +145,11 @@ def main(args):
else:
if args.distributed:
print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
dataset = torchvision.datasets.Kinetics400(
traindir,
dataset = torchvision.datasets.Kinetics(
args.data_path,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing this is done, because we can now select the split through the dataset class rather than using a different root directory, correct? If yes, can we remove the traindir and valdir variables above and from the CLI?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup; and done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier @bjuncek I checked on the servers, I believe there is a reason why these parameter are there. The datasets contain the training/val data at different preprocessed sizes. Example:

/datasets01/kinetics/070618/400$ ls 
README  list  list_cvt  scripts  test  train  train_avi-160p  train_avi-288p  val  val_avi-160p  val_avi-288p

Can we please restore this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed with @datumbox :

the different pre-processed size don't seem to be part of the original downloadable dataset. Looks more like an internal pre-processing that was made by someone.

IF someone was relying on this originally, they should have gotten deprecation warning telling them to use Kinetics, and they should have noticed that this wasn't supported anymore, which they should have told us.

This never happened, so we can assume this support for the different sizes isn't actually needed. If it is, we'll know soon enough.

So we're OK to remove that.

frames_per_clip=args.clip_len,
num_classes=args.kinetics_version,
split="train",
step_between_clips=1,
transform=transform_train,
frame_rate=15,
Expand Down Expand Up @@ -179,9 +181,11 @@ def main(args):
else:
if args.distributed:
print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
dataset_test = torchvision.datasets.Kinetics400(
valdir,
dataset_test = torchvision.datasets.Kinetics(
args.data_path,
frames_per_clip=args.clip_len,
num_classes=args.kinetics_version,
split="val",
step_between_clips=1,
transform=transform_test,
frame_rate=15,
Expand Down Expand Up @@ -312,8 +316,9 @@ def parse_args():
parser = argparse.ArgumentParser(description="PyTorch Video Classification Training")

parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", type=str, help="dataset path")
parser.add_argument("--train-dir", default="train_avi-480p", type=str, help="name of train dir")
parser.add_argument("--val-dir", default="val_avi-480p", type=str, help="name of val dir")
parser.add_argument(
"--kinetics-version", default="400", type=str, choices=["400", "600"], help="Select kinetics version"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to update the commands shown at README.md.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

)
parser.add_argument("--model", default="r2plus1d_18", type=str, help="model name")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip")
Expand Down
7 changes: 0 additions & 7 deletions references/video_classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@
import torch.nn as nn


class ConvertBHWCtoBCHW(nn.Module):
"""Convert tensor from (B, H, W, C) to (B, C, H, W)"""

def forward(self, vid: torch.Tensor) -> torch.Tensor:
return vid.permute(0, 3, 1, 2)


class ConvertBCHWtoCBHW(nn.Module):
"""Convert tensor from (B, C, H, W) to (C, B, H, W)"""

Expand Down
1 change: 1 addition & 0 deletions torchvision/datasets/kinetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def __init__(
warnings.warn(
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
"Please use Kinetics(..., num_classes='400') instead."
"Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
)
if any(value is not None for value in (num_classes, split, download, num_download_workers)):
raise RuntimeError(
Expand Down
11 changes: 4 additions & 7 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,11 @@ def _read_from_stream(
gc.collect()

if pts_unit == "sec":
# TODO: we should change all of this from ground up to simply take
# sec and convert to MS in C++
start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
else:
warnings.warn(
"The pts_unit 'pts' gives wrong results and will be removed in a "
+ "follow-up version. Please use pts_unit 'sec'."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of this warning, which seems to do the right thing, can't we change the reference script to not use "pts"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really; or rather, it used to. The issue is that the backend (both pyav and VideoReader) consume data in pts which is unfortunate as they are not standardised per stream (e.g. audio and video have different timebases so pts=24 means different things). In a PR some time ago, we've tried to fix this by introducing sec as a default unit that would then be converted to appropriate stream. This caused an imprecision issue (bc thigs would be incorrectly rounded) which broke all datasets bc VideoClips creation requires fairly precise pts and such rounding erros cause major issues. We've initially fixed these by using pts as a default in VideoClips (which raises this warrning). This would be fine (i.e. I'd keep the warning and try to re-work the video clips) if all else was equal, but the PR mentioned above was reverted for internal issues, so this warrning is not really relevant anymore.

Having said that, I have a major PR with a VideoReading re-work in plan that will make this redundant in a way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds reasonable, but I have to little background to "sign this off". Who else from the team would be able to judge this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prabhat00155 most likely, as he was the one who originally implemented and then reverted the sync PR.
But we can also keep this open as I'm bound to finish the rework by the end of next week

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed with @datumbox

We'll put this warning back. There should only be 2 reasons to remove this warning:

  • fix the actual bug
  • completely forbid "pts".


frames = {}
should_buffer = True
Expand All @@ -176,9 +173,9 @@ def _read_from_stream(
# can't use regex directly because of some weird characters sometimes...
pos = extradata.find(b"DivX")
d = extradata[pos:]
o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
if o is None:
o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
if o is not None:
should_buffer = o.group(3) == b"p"
seek_offset = start_offset
Expand Down