diff --git a/docs/source/models.rst b/docs/source/models.rst
index 50af05360e4..39543cb8027 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -98,58 +98,6 @@ You can construct a model with random weights by calling its constructor:
     convnext_large = models.convnext_large()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
-These can be constructed by passing ``pretrained=True``:
-
-.. code:: python
-
-    import torchvision.models as models
-    resnet18 = models.resnet18(pretrained=True)
-    alexnet = models.alexnet(pretrained=True)
-    squeezenet = models.squeezenet1_0(pretrained=True)
-    vgg16 = models.vgg16(pretrained=True)
-    densenet = models.densenet161(pretrained=True)
-    inception = models.inception_v3(pretrained=True)
-    googlenet = models.googlenet(pretrained=True)
-    shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
-    mobilenet_v2 = models.mobilenet_v2(pretrained=True)
-    mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
-    mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
-    resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
-    wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
-    mnasnet = models.mnasnet1_0(pretrained=True)
-    efficientnet_b0 = models.efficientnet_b0(pretrained=True)
-    efficientnet_b1 = models.efficientnet_b1(pretrained=True)
-    efficientnet_b2 = models.efficientnet_b2(pretrained=True)
-    efficientnet_b3 = models.efficientnet_b3(pretrained=True)
-    efficientnet_b4 = models.efficientnet_b4(pretrained=True)
-    efficientnet_b5 = models.efficientnet_b5(pretrained=True)
-    efficientnet_b6 = models.efficientnet_b6(pretrained=True)
-    efficientnet_b7 = models.efficientnet_b7(pretrained=True)
-    efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True)
-    efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True)
-    efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True)
-    regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
-    regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
-    regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
-    regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
-    regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
-    regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
-    regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
-    regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
-    regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
-    regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
-    regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
-    regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
-    regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
-    regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
-    vit_b_16 = models.vit_b_16(pretrained=True)
-    vit_b_32 = models.vit_b_32(pretrained=True)
-    vit_l_16 = models.vit_l_16(pretrained=True)
-    vit_l_32 = models.vit_l_32(pretrained=True)
-    convnext_tiny = models.convnext_tiny(pretrained=True)
-    convnext_small = models.convnext_small(pretrained=True)
-    convnext_base = models.convnext_base(pretrained=True)
-    convnext_large = models.convnext_large(pretrained=True)
 
 Instancing a pre-trained model will download its weights to a cache directory.
 This directory can be set using the `TORCH_HOME` environment variable. See
@@ -525,7 +473,7 @@ Obtaining a pre-trained quantized model can be done with a few lines of code:
 
 ..
code:: python import torchvision.models as models - model = models.quantization.mobilenet_v2(pretrained=True, quantize=True) + model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True) model.eval() # run the model with quantized inputs and weights out = model(torch.rand(1, 3, 224, 224)) diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index 770610fb971..5149ebc541b 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -96,7 +96,7 @@ def plot(imgs, **imshow_kwargs): def preprocess(img1_batch, img2_batch): img1_batch = F.resize(img1_batch, size=[520, 960]) img2_batch = F.resize(img2_batch, size=[520, 960]) - return transforms(img1_batch, img2_batch)[:2] + return transforms(img1_batch, img2_batch) img1_batch, img2_batch = preprocess(img1_batch, img2_batch) diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py index a826a2523f2..7bb68617a17 100644 --- a/gallery/plot_repurposing_annotations.py +++ b/gallery/plot_repurposing_annotations.py @@ -146,7 +146,7 @@ def show(imgs): print(img.size()) tranforms = weights.transforms() -img, _ = tranforms(img) +img = tranforms(img) target = {} target["boxes"] = boxes target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64) diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 27fd97681c0..7f92d54ebdd 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -81,7 +81,7 @@ def show(imgs): weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() -batch, _ = transforms(batch_int) +batch = transforms(batch_int) model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() @@ -131,7 +131,7 @@ def show(imgs): model = fcn_resnet50(weights=weights, progress=False) model = model.eval() -normalized_batch, _ = transforms(batch) +normalized_batch = transforms(batch) output = model(normalized_batch)['out'] print(output.shape, output.min().item(), output.max().item()) @@ -272,7 +272,7 @@ def show(imgs): weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() -batch, _ = transforms(batch_int) +batch = transforms(batch_int) model = maskrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() @@ -397,7 +397,7 @@ def show(imgs): weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() -person_float, _ = transforms(person_int) +person_float = transforms(person_int) model = keypointrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() diff --git a/references/classification/README.md b/references/classification/README.md index 173fb454995..c274c997791 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -43,7 +43,7 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model ``` torchrun --nproc_per_node=8 train.py --model inception_v3\ - --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained + --test-only --weights Inception_V3_Weights.IMAGENET1K_V1 ``` ### ResNet @@ -96,22 +96,14 @@ The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTo All models were trained using Bicubic interpolation and each have custom crop and resize sizes. 
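Because every weights enum now bundles its own inference preset, the per-variant interpolation, resize and crop values travel with the weights instead of being spelled out on the command line. A minimal sketch of that lookup, assuming the multi-weight API introduced here (the B4 sizes match the values registered in `torchvision/models/efficientnet.py`):

```
import torch
import torchvision

# Each weights enum carries its eval preprocessing; for EfficientNet-B4 this is
# a bicubic resize to 384 followed by a 380 center crop.
weights = torchvision.models.get_weight("EfficientNet_B4_Weights.IMAGENET1K_V1")
preprocess = weights.transforms()

model = torchvision.models.efficientnet_b4(weights=weights).eval()
with torch.no_grad():
    logits = model(preprocess(torch.rand(3, 500, 500)).unsqueeze(0))
print(logits.shape)  # torch.Size([1, 1000])
```

This is the same lookup `train.py` performs when `--test-only --weights` is passed, which is why the validation commands below only need the model name and the weights enum.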
To validate the models use the following commands: ``` -torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\ - --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\ - --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\ - --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\ - --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\ - --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\ - --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1 ``` diff --git a/references/classification/train.py b/references/classification/train.py index 569cf3009e7..eb8b56c1ad0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -15,12 +15,6 @@ from torchvision.transforms.functional import InterpolationMode -try: - from torchvision import prototype -except ImportError: - prototype = None - - def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") @@ -154,18 +148,13 @@ def load_data(traindir, valdir, args): print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: - if not args.prototype: + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + preprocessing = weights.transforms() + else: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) - else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - 
preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.ImageClassificationEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation - ) dataset_test = torchvision.datasets.ImageFolder( valdir, @@ -191,10 +180,6 @@ def load_data(traindir, valdir, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -236,10 +221,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) - else: - model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) + model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: @@ -446,12 +428,6 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") @@ -496,14 +472,6 @@ def get_args_parser(add_help=True): parser.add_argument( "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" ) - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index 111777a860b..c0e5af1dcfc 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -12,17 +12,7 @@ from train import train_one_epoch, evaluate, load_data -try: - from torchvision import prototype -except ImportError: - prototype = None - - def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -56,10 +46,7 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model - if not args.prototype: - model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) - else: - model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) + model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model.to(device) if not (args.test_only or args.post_training_quantize): @@ -264,14 +251,6 @@ def get_args_parser(add_help=True): "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" ) parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 7f573415c4c..27398d97234 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T from torchvision import models as M # Classification - model = M.mobilenet_v3_large(pretrained=False) + model = M.mobilenet_v3_large() print(store_model_weights(model, './class.pth')) # Quantized Classification - model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False) + model = M.quantization.mobilenet_v3_large(quantize=False) model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') _ = torch.ao.quantization.prepare_qat(model, inplace=True) print(store_model_weights(model, './qat.pth')) # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False) + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn() print(store_model_weights(model, './obj.pth')) # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True) + model = M.segmentation.deeplabv3_mobilenet_v3_large(aux_loss=True) print(store_model_weights(model, './segm.pth', strict=False)) Args: diff --git a/references/detection/README.md b/references/detection/README.md index 3695644138b..aec7c10e1b5 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -24,35 +24,35 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs. 
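With `--pretrained` removed, the detection recipes now state the backbone initialization explicitly through `--weights-backbone`, which `train.py` simply forwards to the model builder. Roughly what that flag amounts to, as a hedged sketch rather than the script's literal code path (the enum value is the one used in the Faster R-CNN command below):

```
import torchvision
from torchvision.models import ResNet50_Weights

# Train a COCO detector from scratch while initializing only the backbone
# from ImageNet weights, mirroring --weights-backbone in the commands below.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=None,                                     # no COCO checkpoint
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,  # ImageNet-initialized backbone
    num_classes=91,                                    # COCO categories
)
model.train()
```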
``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large 320 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### FCOS ResNet-50 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fcos_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### RetinaNet ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model retinanet_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### SSD300 VGG16 @@ -60,7 +60,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssd300_vgg16 --epochs 120\ --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\ - --weight-decay 0.0005 --data-augmentation ssd + --weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES ``` ### SSDlite320 MobileNetV3-Large @@ -68,7 +68,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\ --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\ - --weight-decay 0.00004 --data-augmentation ssdlite + --weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` @@ -76,7 +76,7 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` @@ -84,5 +84,5 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\ - --lr-steps 36 43 --aspect-ratio-group-factor 3 + --lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` diff --git a/references/detection/train.py b/references/detection/train.py index 3909e6413d0..0e0a0d70fad 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -33,12 +33,6 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(name, image_set, transform, data_path): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] @@ -50,14 +44,12 @@ def get_dataset(name, image_set, 
transform, data_path): def get_transform(train, args): if train: return presets.DetectionPresetTrain(args.data_augmentation) - elif not args.prototype: - return presets.DetectionPresetEval() + elif args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + return lambda img, target=None: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.ObjectDetectionEval() + return presets.DetectionPresetEval() def get_args_parser(add_help=True): @@ -132,25 +124,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") @@ -159,10 +138,6 @@ def get_args_parser(add_help=True): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -204,12 +179,9 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - if not args.prototype: - model = torchvision.models.detection.__dict__[args.model]( - pretrained=args.pretrained, num_classes=num_classes, **kwargs - ) - else: - model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) + model = torchvision.models.detection.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs + ) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md index a7620ce4be6..a7ac0223739 100644 --- a/references/optical_flow/README.md +++ b/references/optical_flow/README.md @@ -51,7 +51,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \ ### Evaluation ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 ``` This should give an epe of about 1.3822 on the clean pass and 2.7161 on the @@ -67,6 +67,6 @@ Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: You can also evaluate on Kitti train: ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679 ``` diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 83952242eb9..1a50d1c617d 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -9,11 +9,6 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K -try: - from torchvision import prototype -except ImportError: - prototype = None - def get_train_dataset(stage, dataset_root): if stage == "chairs": @@ -138,12 +133,10 @@ def inner_loop(blob): def evaluate(model, args): val_datasets = args.val_dataset or [] - if args.prototype: - if args.weights: - weights = prototype.models.get_weight(args.weights) - preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.OpticalFlowEval() + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + preprocessing = lambda img1, img2, flow=None, valid=None: trans(img1, img2) + (flow, valid) # noqa: E731 else: preprocessing = OpticalFlowPresetEval() @@ -201,20 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") utils.setup_ddp(args) + args.test_only = args.train_dataset is None if args.distributed and args.device == "cpu": raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun") device = torch.device(args.device) - if args.prototype: - model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights) - else: - model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) + model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) if args.distributed: model = model.to(args.local_rank) @@ -228,7 +215,7 @@ def main(args): checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) - if args.train_dataset is None: + if args.test_only: # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True @@ -356,8 +343,7 @@ def get_args_parser(add_help=True): parser.add_argument( "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small" ) - # TODO: resume, pretrained, and weights should be in an exclusive arg group - parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") + # TODO: resume and weights should be in an exclusive arg group parser.add_argument( "--num_flow_updates", @@ -376,13 +362,6 @@ def get_args_parser(add_help=True): required=True, ) - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)") diff --git a/references/segmentation/README.md b/references/segmentation/README.md index e9b5391215a..2c7391c8380 100644 --- a/references/segmentation/README.md +++ b/references/segmentation/README.md @@ -14,30 +14,30 @@ You must modify the following flags: ## fcn_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## fcn_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model 
deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ## lraspp_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 5dc03945bd7..b4e55acd407 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -11,12 +11,6 @@ from torch import nn -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(dir_path, name, image_set, transform): def sbd(*args, **kwargs): return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) @@ -35,14 +29,12 @@ def sbd(*args, **kwargs): def get_transform(train, args): if train: return presets.SegmentationPresetTrain(base_size=520, crop_size=480) - elif not args.prototype: - return presets.SegmentationPresetEval(base_size=520) + elif args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + return lambda img, target=None: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.SemanticSegmentationEval(resize_size=520) + return presets.SegmentationPresetEval(base_size=520) def criterion(inputs, target): @@ -100,10 +92,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -135,16 +123,9 @@ def main(args): dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn ) - if not args.prototype: - model = torchvision.models.segmentation.__dict__[args.model]( - pretrained=args.pretrained, - num_classes=num_classes, - aux_loss=args.aux_loss, - ) - else: - model = prototype.models.segmentation.__dict__[args.model]( - weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss - ) + model = torchvision.models.segmentation.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss + ) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -272,24 +253,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 04039c9a4f1..d24169e42dd 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -6,8 +6,8 @@ class VideoClassificationPresetTrain: def __init__( self, - resize_size, crop_size, + resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), hflip_prob=0.5, @@ -27,7 +27,7 @@ def __call__(self, x): class VideoClassificationPresetEval: - def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): + def __init__(self, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): self.transforms = transforms.Compose( [ ConvertBHWCtoBCHW(), diff --git a/references/video_classification/train.py b/references/video_classification/train.py index d36785ddf96..da7ef9fc607 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,11 +12,6 @@ from torch.utils.data.dataloader import default_collate from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler -try: - from torchvision import prototype -except ImportError: - prototype = None - def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): model.train() @@ -96,10 +91,6 @@ def collate_fn(batch): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. 
Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -120,7 +111,7 @@ def main(args): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112)) + transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_train from {cache_path}") @@ -150,14 +141,11 @@ def main(args): print("Loading validation data") cache_path = _get_cache_path(valdir) - if not args.prototype: - transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + transform_test = weights.transforms() else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - transform_test = weights.transforms() - else: - transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171)) + transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -208,10 +196,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - else: - model = prototype.models.video.__dict__[args.model](weights=args.weights) + model = torchvision.models.video.__dict__[args.model](weights=args.weights) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -352,24 +337,11 @@ def parse_args(): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 4bfe03d1ea0..a07b501e15b 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -3,12 +3,14 @@ import pytest import test_models as TM +import torch from torchvision import models from torchvision.models._api import WeightsEnum, Weights from torchvision.models._utils import handle_legacy_interface -run_if_test_with_prototype = pytest.mark.skipif( - os.getenv("PYTORCH_TEST_WITH_EXTENDED") != "1", + +run_if_test_with_extended = pytest.mark.skipif( + os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1", reason="Extended tests are disabled by default. 
Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.", ) @@ -76,7 +78,7 @@ def test_naming_conventions(model_fn): + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), ) -@run_if_test_with_prototype +@run_if_test_with_extended def test_schema_meta_validation(model_fn): classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"] defaults = { @@ -123,6 +125,63 @@ def test_schema_meta_validation(model_fn): assert not bad_names +@pytest.mark.parametrize( + "model_fn", + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video) + + TM.get_models_from_module(models.optical_flow), +) +@run_if_test_with_extended +def test_transforms_jit(model_fn): + model_name = model_fn.__name__ + weights_enum = _get_model_weights(model_fn) + if len(weights_enum) == 0: + pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") + + defaults = { + "models": { + "input_shape": (1, 3, 224, 224), + }, + "detection": { + "input_shape": (3, 300, 300), + }, + "quantization": { + "input_shape": (1, 3, 224, 224), + }, + "segmentation": { + "input_shape": (1, 3, 520, 520), + }, + "video": { + "input_shape": (1, 4, 112, 112, 3), + }, + "optical_flow": { + "input_shape": (1, 3, 128, 128), + }, + } + module_name = model_fn.__module__.split(".")[-2] + + kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})} + input_shape = kwargs.pop("input_shape") + x = torch.rand(input_shape) + if module_name == "optical_flow": + args = (x, x) + else: + args = (x,) + + problematic_weights = [] + for w in weights_enum: + transforms = w.transforms() + try: + TM._check_jit_scriptable(transforms, args) + except Exception: + problematic_weights.append(w) + + assert not problematic_weights + + # With this filter, every unexpected warning will be turned into an error @pytest.mark.filterwarnings("error") class TestHandleLegacyInterface: diff --git a/test/test_models.py b/test/test_models.py index 5bef9e24d9f..0d45d61df13 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,8 +133,7 @@ def get_export_import_copy(m): if eager_out is None: with torch.no_grad(), freeze_rng_state(): - if unwrapper: - eager_out = nn_module(*args) + eager_out = nn_module(*args) with torch.no_grad(), freeze_rng_state(): script_out = sm(*args) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 4df533000f9..6ee5b98c673 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -55,7 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AlexNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "AlexNet", diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 8d25e77eaa1..8774b9a1bc2 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -7,7 +7,7 @@ 
from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -218,7 +218,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), + transforms=partial(ImageClassification, crop_size=224, resize_size=236), meta={ **_COMMON_META, "num_params": 28589128, @@ -232,7 +232,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), + transforms=partial(ImageClassification, crop_size=224, resize_size=230), meta={ **_COMMON_META, "num_params": 50223688, @@ -246,7 +246,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 88591464, @@ -260,7 +260,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 197767336, diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index b0de4529902..2ffb29c54cb 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -280,7 +280,7 @@ def _densenet( class DenseNet121_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 7978856, @@ -294,7 +294,7 @@ class DenseNet121_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 28681000, @@ -308,7 +308,7 @@ class DenseNet161_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 14149480, @@ -322,7 +322,7 @@ class 
DenseNet169_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 20013928, diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index b5a1df4502c..7d18fbe90a3 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -336,7 +336,7 @@ def forward(self, x): class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 41755286, @@ -350,7 +350,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, @@ -364,7 +364,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 7948cd76ab2..27e54a565f2 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -11,7 +11,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -646,7 +646,7 @@ def forward( class FCOS_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "FCOS", diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 522545293a0..2a554a6f56e 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import 
handle_legacy_interface, _ovewrite_value_param @@ -318,7 +318,7 @@ def forward(self, x): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_LEGACY = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 59137258, @@ -329,7 +329,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ) COCO_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 59137258, diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 47b32984991..fb60ffcbb0a 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -308,7 +308,7 @@ def __init__(self, in_channels, dim_reduced, num_classes): class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "MaskRCNN", diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 048d68b52dc..49b9acf45e4 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -10,7 +10,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -588,7 +588,7 @@ def forward(self, images, targets=None): class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "RetinaNet", diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index b907b3fccf8..c30919e621c 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops import boxes as box_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -28,7 +28,7 @@ class SSD300_VGG16_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "SSD", diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 
dad28cfed13..93023337d11 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops.misc import Conv2dNormActivation -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet from .._api import WeightsEnum, Weights @@ -187,7 +187,7 @@ def _mobilenet_extractor( class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "SSDLite", diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 9665c169bbf..b9d3b9b30c9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -10,7 +10,7 @@ from torchvision.ops import StochasticDepth from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -458,7 +458,7 @@ class EfficientNet_B0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial( - ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -475,7 +475,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -488,7 +488,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR ), meta={ **_COMMON_META_V1, @@ -507,7 +507,7 @@ class EfficientNet_B2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial( - ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -524,7 +524,7 @@ class EfficientNet_B3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial( - ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -541,7 +541,7 @@ class EfficientNet_B4_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial( - ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -558,7 +558,7 @@ class EfficientNet_B5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial( - ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -575,7 +575,7 @@ class EfficientNet_B6_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial( - ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -592,7 +592,7 @@ class EfficientNet_B7_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial( - ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -609,7 +609,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BILINEAR, @@ -629,7 +629,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=480, resize_size=480, interpolation=InterpolationMode.BILINEAR, @@ -649,7 +649,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=480, resize_size=480, interpolation=InterpolationMode.BICUBIC, diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index e09e6788097..ced92571974 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tensor: class GoogLeNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 24d084b62d2..816fab45549 100644 --- a/torchvision/models/inception.py 
+++ b/torchvision/models/inception.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -410,7 +410,7 @@ def forward(self, x: Tensor) -> Tensor: class Inception_V3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassification, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 287911edbec..578e77f7934 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -226,7 +226,7 @@ def _load_from_state_dict( class MNASNet0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2218512, @@ -245,7 +245,7 @@ class MNASNet0_75_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 4383312, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 1e19db1a314..085049117ec 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -7,7 +7,7 @@ from torch import nn from ..ops.misc import Conv2dNormActivation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -209,7 +209,7 @@ def forward(self, x: Tensor) -> Tensor: class MobileNet_V2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", @@ -219,7 +219,7 @@ class MobileNet_V2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 3a98456416d..91e1ea91a94 
100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _mobilenet_v3( class MobileNet_V3_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 5483032, @@ -328,7 +328,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5483032, @@ -343,7 +343,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): class MobileNet_V3_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2542856, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index a506224d4b3..244d2b2fac1 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,7 +8,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ...transforms import OpticalFlowEval, InterpolationMode +from ...transforms._presets import OpticalFlow, InterpolationMode from ...utils import _log_api_usage_once from .._api import Weights, WeightsEnum from .._utils import handle_legacy_interface @@ -523,7 +523,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-things.pth) url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -538,7 +538,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -553,7 +553,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_V1 = Weights( # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -568,7 +568,7 @@ class Raft_Large_Weights(WeightsEnum): # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -581,7 +581,7 @@ class Raft_Large_Weights(WeightsEnum): 
C_T_SKHT_K_V1 = Weights( # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -596,7 +596,7 @@ class Raft_Large_Weights(WeightsEnum): # Same as CT_SKHT with extra fine-tuning on Kitti # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -612,7 +612,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-small.pth) url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 990162, @@ -626,7 +626,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 990162, diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index befc2299c06..9944e470352 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.nn import functional as F -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -109,7 +109,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class GoogLeNet_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 697d99d4027..9a732f79fb7 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -9,7 +9,7 @@ from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -175,7 +175,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class Inception_V3_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassification, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git 
a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 40f5cb544fd..1def3d24b28 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -7,7 +7,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights from ...ops.misc import Conv2dNormActivation -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -67,7 +67,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class MobileNet_V2_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV2", diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4b79b7f26ae..4a203ca7095 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -157,7 +157,7 @@ def _mobilenet_v3_model( class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV3", diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 666b1b23163..ab512a7413f 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -13,7 +13,7 @@ ResNeXt101_32X8D_Weights, ) -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -161,7 +161,7 @@ def _resnet( class ResNet18_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -178,7 +178,7 @@ class ResNet18_QuantizedWeights(WeightsEnum): class ResNet50_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": 
"ResNet", @@ -191,7 +191,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -208,7 +208,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -221,7 +221,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index c5bfe698636..a3a26120479 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -6,7 +6,7 @@ from torch import Tensor from torchvision.models import shufflenetv2 -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -118,7 +118,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -133,7 +133,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 1015c21b858..72093686d84 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -416,7 +416,7 @@ def _regnet( class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 4344144, @@ -427,7 +427,7 @@ 
class RegNet_Y_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 4344144, @@ -442,7 +442,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): class RegNet_Y_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 6432512, @@ -453,7 +453,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 6432512, @@ -468,7 +468,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): class RegNet_Y_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 11202430, @@ -479,7 +479,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 11202430, @@ -494,7 +494,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): class RegNet_Y_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 19436338, @@ -505,7 +505,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 19436338, @@ -520,7 +520,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): class RegNet_Y_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 39381472, @@ -531,7 +531,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39381472, @@ -546,7 +546,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): class RegNet_Y_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + 
transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 83590140, @@ -557,7 +557,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 83590140, @@ -572,7 +572,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): class RegNet_Y_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 145046770, @@ -583,7 +583,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 145046770, @@ -603,7 +603,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 5495976, @@ -614,7 +614,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5495976, @@ -629,7 +629,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 7259656, @@ -640,7 +640,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 7259656, @@ -655,7 +655,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): class RegNet_X_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 9190136, @@ -666,7 +666,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 9190136, @@ -681,7 +681,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): class RegNet_X_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 15296552, @@ -692,7 +692,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 15296552, @@ -707,7 +707,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): class RegNet_X_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 39572648, @@ -718,7 +718,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39572648, @@ -733,7 +733,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): class RegNet_X_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 54278536, @@ -744,7 +744,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 54278536, @@ -759,7 +759,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): class RegNet_X_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 107811560, @@ -770,7 +770,7 @@ class RegNet_X_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 107811560, diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 159749df006..8f44e553296 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -313,7 +313,7 @@ def _resnet( class ResNet18_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, 
crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -330,7 +330,7 @@ class ResNet18_Weights(WeightsEnum): class ResNet34_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -347,7 +347,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -360,7 +360,7 @@ class ResNet50_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -377,7 +377,7 @@ class ResNet50_Weights(WeightsEnum): class ResNet101_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -390,7 +390,7 @@ class ResNet101_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -407,7 +407,7 @@ class ResNet101_Weights(WeightsEnum): class ResNet152_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -420,7 +420,7 @@ class ResNet152_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -437,7 +437,7 @@ class ResNet152_Weights(WeightsEnum): class ResNeXt50_32X4D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -450,7 +450,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -467,7 +467,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): class ResNeXt101_32X8D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + 
transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -480,7 +480,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -497,7 +497,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): class Wide_ResNet50_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -510,7 +510,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -527,7 +527,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): class Wide_ResNet101_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -540,7 +540,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 6e8bf0c398b..41ab34bae07 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import functional as F -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -140,7 +140,7 @@ def _deeplabv3_resnet( class DeepLabV3_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 42004074, @@ -155,7 +155,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): class DeepLabV3_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 60996202, @@ -170,7 +170,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( 
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 11029328, diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 5a3ca1f654f..6a760be36dc 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -3,7 +3,7 @@ from torch import nn -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -59,7 +59,7 @@ def __init__(self, in_channels: int, channels: int) -> None: class FCN_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 35322218, @@ -74,7 +74,7 @@ class FCN_ResNet50_Weights(WeightsEnum): class FCN_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 54314346, diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index d1fe15a350d..33684526c6b 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES @@ -96,7 +96,7 @@ def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ "task": "image_semantic_segmentation", "architecture": "LRASPP", diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index b38c0ac2974..e988b819078 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -198,7 +198,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -212,7 
+212,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): class ShuffleNet_V2_X1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index d495b3148e5..bde8b5efcfd 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.init as init -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -128,7 +128,7 @@ def _squeezenet( class SqueezeNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "min_size": (21, 21), @@ -143,7 +143,7 @@ class SqueezeNet1_0_Weights(WeightsEnum): class SqueezeNet1_1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "min_size": (17, 17), diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 27325c9016c..93bfd5e6ba3 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -120,7 +120,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b class VGG11_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 132863336, @@ -134,7 +134,7 @@ class VGG11_Weights(WeightsEnum): class VGG11_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 132868840, @@ -148,7 +148,7 @@ class VGG11_BN_Weights(WeightsEnum): class VGG13_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 133047848, @@ -162,7 +162,7 @@ class VGG13_Weights(WeightsEnum): class VGG13_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 133053736, @@ -176,7 +176,7 @@ class 
VGG13_BN_Weights(WeightsEnum): class VGG16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 138357544, @@ -190,7 +190,7 @@ class VGG16_Weights(WeightsEnum): IMAGENET1K_FEATURES = Weights( url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), @@ -210,7 +210,7 @@ class VGG16_Weights(WeightsEnum): class VGG16_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 138365992, @@ -224,7 +224,7 @@ class VGG16_BN_Weights(WeightsEnum): class VGG19_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 143667240, @@ -238,7 +238,7 @@ class VGG19_Weights(WeightsEnum): class VGG19_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 143678248, diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index a6b779d10f1..618ddb96ba2 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch import Tensor -from ...transforms import VideoClassificationEval, InterpolationMode +from ...transforms._presets import VideoClassification, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES @@ -322,7 +322,7 @@ def _video_resnet( class R3D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R3D", @@ -337,7 +337,7 @@ class R3D_18_Weights(WeightsEnum): class MC3_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "MC3", @@ -352,7 +352,7 @@ class MC3_18_Weights(WeightsEnum): class R2Plus1D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R(2+1)D", diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 
801e7adc981..fb34cf3c8e1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from ..ops.misc import Conv2dNormActivation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _vision_transformer( class ViT_B_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 86567656, @@ -334,7 +334,7 @@ class ViT_B_16_Weights(WeightsEnum): class ViT_B_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 88224232, @@ -351,7 +351,7 @@ class ViT_B_32_Weights(WeightsEnum): class ViT_L_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), + transforms=partial(ImageClassification, crop_size=224, resize_size=242), meta={ **_COMMON_META, "num_params": 304326632, @@ -368,7 +368,7 @@ class ViT_L_16_Weights(WeightsEnum): class ViT_L_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 306535400, diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 94ec34ebe98..77680a14f0d 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1,9 +1,2 @@ from .transforms import * from .autoaugment import * -from ._presets import ( - ObjectDetectionEval, - ImageClassificationEval, - SemanticSegmentationEval, - VideoClassificationEval, - OpticalFlowEval, -) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 1776d876ccb..0bfb1cf9b38 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -1,4 +1,8 @@ -from typing import Dict, Optional, Tuple +""" +This file is part of the private API. Please do not use these classes directly, as they will be modified in +future versions without warning. The classes should be accessed only via the transforms argument of Weights.
+""" +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -7,24 +11,22 @@ __all__ = [ - "ObjectDetectionEval", - "ImageClassificationEval", - "VideoClassificationEval", - "SemanticSegmentationEval", - "OpticalFlowEval", + "ObjectDetection", + "ImageClassification", + "VideoClassification", + "SemanticSegmentation", + "OpticalFlow", ] -class ObjectDetectionEval(nn.Module): - def forward( - self, img: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: +class ObjectDetection(nn.Module): + def forward(self, img: Tensor) -> Tensor: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) - return F.convert_image_dtype(img, torch.float), target + return F.convert_image_dtype(img, torch.float) -class ImageClassificationEval(nn.Module): +class ImageClassification(nn.Module): def __init__( self, crop_size: int, @@ -50,7 +52,7 @@ def forward(self, img: Tensor) -> Tensor: return img -class VideoClassificationEval(nn.Module): +class VideoClassification(nn.Module): def __init__( self, crop_size: Tuple[int, int], @@ -67,55 +69,59 @@ def __init__( self._interpolation = interpolation def forward(self, vid: Tensor) -> Tensor: - vid = vid.permute(0, 3, 1, 2) # (T, H, W, C) => (T, C, H, W) + need_squeeze = False + if vid.ndim < 5: + vid = vid.unsqueeze(dim=0) + need_squeeze = True + + vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W) + N, T, C, H, W = vid.shape + vid = vid.view(-1, C, H, W) vid = F.resize(vid, self._size, interpolation=self._interpolation) vid = F.center_crop(vid, self._crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self._mean, std=self._std) - return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) + vid = vid.view(N, T, C, H, W) + vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) + + if need_squeeze: + vid = vid.squeeze(dim=0) + return vid -class SemanticSegmentationEval(nn.Module): +class SemanticSegmentation(nn.Module): def __init__( self, resize_size: Optional[int], mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] 
= (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, - interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() self._size = [resize_size] if resize_size is not None else None self._mean = list(mean) self._std = list(std) self._interpolation = interpolation - self._interpolation_target = interpolation_target - def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + def forward(self, img: Tensor) -> Tensor: if isinstance(self._size, list): img = F.resize(img, self._size, interpolation=self._interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) - if target: - if isinstance(self._size, list): - target = F.resize(target, self._size, interpolation=self._interpolation_target) - if not isinstance(target, Tensor): - target = F.pil_to_tensor(target) - target = target.squeeze(0).to(torch.int64) - return img, target - + return img -class OpticalFlowEval(nn.Module): - def forward( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask) +class OpticalFlow(nn.Module): + def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: + if not isinstance(img1, Tensor): + img1 = F.pil_to_tensor(img1) + if not isinstance(img2, Tensor): + img2 = F.pil_to_tensor(img2) - img1 = F.convert_image_dtype(img1, torch.float32) - img2 = F.convert_image_dtype(img2, torch.float32) + img1 = F.convert_image_dtype(img1, torch.float) + img2 = F.convert_image_dtype(img2, torch.float) # map [0, 1] into [-1, 1] img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) @@ -124,19 +130,4 @@ def forward( img1 = img1.contiguous() img2 = img2.contiguous() - return img1, img2, flow, valid_flow_mask - - def _pil_or_numpy_to_tensor( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - if not isinstance(img1, Tensor): - img1 = F.pil_to_tensor(img1) - if not isinstance(img2, Tensor): - img2 = F.pil_to_tensor(img2) - - if flow is not None and not isinstance(flow, Tensor): - flow = torch.from_numpy(flow) - if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor): - valid_flow_mask = torch.from_numpy(valid_flow_mask) - - return img1, img2, flow, valid_flow_mask + return img1, img2
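
The renames above only change where the presets live and what they are called; the intended way to reach them is unchanged, namely through the ``transforms`` field of a weights entry. Below is a minimal sketch (not part of the patch) of that call pattern, assuming the weights enums are re-exported from ``torchvision.models`` as elsewhere in this series; the ResNet-50 entry is used purely for illustration, and running it downloads the checkpoint.

.. code:: python

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    # The `transforms` field of a Weights entry is a partial over the private
    # ImageClassification preset; calling it builds the preset instance.
    weights = ResNet50_Weights.IMAGENET1K_V2
    preprocess = weights.transforms()
    model = resnet50(weights=weights).eval()

    # Stand-in for a decoded uint8 image tensor of shape (C, H, W).
    img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
    batch = preprocess(img).unsqueeze(0)  # resize, center-crop, rescale, normalize -> (1, 3, 224, 224)

    with torch.no_grad():
        scores = model(batch).softmax(dim=1)
    print(scores.argmax(dim=1).item())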
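
With the optional ``target``, ``flow`` and ``valid_flow_mask`` arguments removed, the detection, segmentation and optical-flow presets now transform only the images, so callers unpack a single tensor (or a pair of tensors for optical flow) instead of a tuple that echoed the targets back. A small sketch of the new call shapes; the private classes are instantiated directly here only for illustration, in practice they come from ``weights.transforms()``.

.. code:: python

    import torch
    from torchvision.transforms._presets import ObjectDetection, SemanticSegmentation, OpticalFlow

    img = torch.randint(0, 256, (3, 320, 320), dtype=torch.uint8)

    # Detection preset: uint8 image in, float image out; no (img, target) tuple any more.
    det_img = ObjectDetection()(img)
    print(det_img.dtype, det_img.shape)  # torch.float32 torch.Size([3, 320, 320])

    # Segmentation preset: only the image is resized and normalized; targets are left to the caller.
    seg_img = SemanticSegmentation(resize_size=520)(img)
    print(seg_img.shape)  # torch.Size([3, 520, 520]) for a square input

    # Optical-flow preset: two frames in, two float frames out, mapped into [-1, 1].
    frame1 = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
    frame2 = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
    f1, f2 = OpticalFlow()(frame1, frame2)
    print(f1.shape, round(f1.min().item(), 2), round(f1.max().item(), 2))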
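
``VideoClassification`` additionally gains batched-clip support: an input with fewer than five dimensions is treated as a single ``(T, H, W, C)`` clip and temporarily unsqueezed, the per-frame transforms run on a flattened ``(N*T, C, H, W)`` view, and the result comes back as ``(N, C, T, H, W)``, or ``(C, T, H, W)`` for a single clip. A sketch of both call shapes (again only illustrative), reusing the crop and resize values of the video ResNet weights above.

.. code:: python

    import torch
    from torchvision.transforms._presets import VideoClassification

    preset = VideoClassification(crop_size=(112, 112), resize_size=(128, 171))

    clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)      # (T, H, W, C)
    batch = torch.randint(0, 256, (4, 16, 240, 320, 3), dtype=torch.uint8)  # (N, T, H, W, C)

    print(preset(clip).shape)   # torch.Size([3, 16, 112, 112]),    i.e. (C, T, H, W)
    print(preset(batch).shape)  # torch.Size([4, 3, 16, 112, 112]), i.e. (N, C, T, H, W)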