Commit 912dafb

Merge branch 'main' into check-rotate-autodiff
2 parents 7bd8d65 + 55f7faf commit 912dafb

File tree

29 files changed: +390 -69 lines changed


.circleci/unittest/linux/scripts/environment.yml

Lines changed: 0 additions & 2 deletions
@@ -9,8 +9,6 @@ dependencies:
   - libpng
   - jpeg
   - ca-certificates
-  # TODO: remove this after https://github.com/pytorch/pytorch/issues/69905 is resolved
-  - pyyaml
   - pip:
     - future
     - pillow >=5.3.0, !=8.3.*

.circleci/unittest/windows/scripts/environment.yml

Lines changed: 0 additions & 2 deletions
@@ -9,8 +9,6 @@ dependencies:
   - libpng
   - jpeg
   - ca-certificates
-  # TODO: remove this after https://github.com/pytorch/pytorch/issues/69905 is resolved
-  - pyyaml
   - pip:
     - future
     - pillow >=5.3.0, !=8.3.*

references/classification/sampler.py

Lines changed: 4 additions & 3 deletions
@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler):
     https://github.com/facebookresearch/deit/blob/main/samplers.py
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available!")
@@ -28,11 +28,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
+        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
         self.total_size = self.num_samples * self.num_replicas
         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
         self.shuffle = shuffle
         self.seed = seed
+        self.repetitions = repetitions

     def __iter__(self):
         # Deterministically shuffle based on epoch
@@ -44,7 +45,7 @@ def __iter__(self):
         indices = list(range(len(self.dataset)))

         # Add extra samples to make it evenly divisible
-        indices = [ele for ele in indices for i in range(3)]
+        indices = [ele for ele in indices for i in range(self.repetitions)]
         indices += indices[: (self.total_size - len(indices))]
         assert len(indices) == self.total_size

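The change above turns the sampler's hard-coded 3x repetition into a `repetitions` parameter. A minimal standalone sketch of the resulting index layout, with a made-up dataset size and replica count (the per-rank slice follows the usual rank:total_size:num_replicas convention of distributed samplers):

import math

dataset_len, num_replicas, repetitions, rank = 10, 4, 3, 0

# Mirrors the updated arithmetic: each sample appears `repetitions` times per epoch.
num_samples = int(math.ceil(dataset_len * float(repetitions) / num_replicas))
total_size = num_samples * num_replicas

indices = list(range(dataset_len))
indices = [ele for ele in indices for _ in range(repetitions)]  # repeat every index
indices += indices[: (total_size - len(indices))]               # pad to a multiple of num_replicas
assert len(indices) == total_size

print(indices[rank:total_size:num_replicas])  # the num_samples entries rank 0 would iterate over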
references/classification/train.py

Lines changed: 5 additions & 2 deletions
@@ -174,7 +174,7 @@ def load_data(traindir, valdir, args):
     print("Creating data loaders")
     if args.distributed:
         if args.ra_sampler:
-            train_sampler = RASampler(dataset, shuffle=True)
+            train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
         else:
             train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
@@ -485,7 +485,10 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
-    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")
+    parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
+    parser.add_argument(
+        "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
+    )

     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

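For reference, the two flags added to get_args_parser can be exercised on their own. This sketch only mirrors those two arguments; the real parser defines many more:

import argparse

parser = argparse.ArgumentParser(description="Repeated Augmentation flags (sketch)")
parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
parser.add_argument(
    "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)

args = parser.parse_args(["--ra-sampler", "--ra-reps", "4"])
print(args.ra_sampler, args.ra_reps)  # True 4
# load_data then forwards the value: RASampler(dataset, shuffle=True, repetitions=args.ra_reps)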
torchvision/csrc/io/image/cpu/decode_jpeg.cpp

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ static void torch_jpeg_set_source_mgr(
 } // namespace

 torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
   // Check that the input tensor dtype is uint8
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
   // Check that the input tensor is 1-dimensional

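C10_LOG_API_USAGE_ONCE records each key at most once per process through PyTorch's API-usage logger. One way to observe these events from Python is sketched below; it assumes the PYTORCH_API_USAGE_STDERR environment variable (a PyTorch debug switch) and a hypothetical sample.jpg on disk:

import os

# The variable is read when the logger is set up, so set it before importing torch/torchvision.
os.environ["PYTORCH_API_USAGE_STDERR"] = "1"

import torch  # noqa: E402
from torchvision.io import read_file, decode_jpeg  # noqa: E402

data = read_file("sample.jpg")  # hypothetical path; logs ...read_write_file.read_file once
img = decode_jpeg(data)         # logs ...decode_jpeg.decode_jpeg once
print(img.shape)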
torchvision/csrc/io/image/cpu/decode_png.cpp

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode,
     bool allow_16_bits) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
   // Check that the input tensor dtype is uint8
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
   // Check that the input tensor is 1-dimensional

torchvision/csrc/io/image/cpu/encode_jpeg.cpp

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@ using JpegSizeType = size_t;
 using namespace detail;

 torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg");
   // Define compression structures and error handling
   struct jpeg_compress_struct cinfo {};
   struct torch_jpeg_error_mgr jerr {};

torchvision/csrc/io/image/cpu/encode_png.cpp

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ void torch_png_write_data(
 } // namespace

 torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png");
   // Define compression structures and error handling
   png_structp png_write;
   png_infop info_ptr;

torchvision/csrc/io/image/cpu/read_write_file.cpp

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,8 @@ std::wstring utf8_decode(const std::string& str) {
 #endif

 torch::Tensor read_file(const std::string& filename) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.image.cpu.read_write_file.read_file");
 #ifdef _WIN32
   // According to
   // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/stat-functions?view=vs-2019,
@@ -76,6 +78,8 @@ torch::Tensor read_file(const std::string& filename) {
 }

 void write_file(const std::string& filename, torch::Tensor& data) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.image.cpu.read_write_file.write_file");
   // Check that the input tensor is on CPU
   TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");

torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ torch::Tensor decode_jpeg_cuda(
     const torch::Tensor& data,
     ImageReadMode mode,
     torch::Device device) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.image.cuda.decode_jpeg_cuda.decode_jpeg_cuda");
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");

   TORCH_CHECK(

torchvision/csrc/io/video/video.cpp

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ void Video::_getDecoderParams(
 } // _get decoder params

 Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
   // set number of threads global
   numThreads_ = numThreads;
   // parse stream information

torchvision/csrc/io/video_reader/video_reader.cpp

Lines changed: 8 additions & 0 deletions
@@ -583,6 +583,8 @@ torch::List<torch::Tensor> read_video_from_memory(
     int64_t audioEndPts,
     int64_t audioTimeBaseNum,
     int64_t audioTimeBaseDen) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.video_reader.video_reader.read_video_from_memory");
   return readVideo(
       false,
       input_video,
@@ -627,6 +629,8 @@ torch::List<torch::Tensor> read_video_from_file(
     int64_t audioEndPts,
     int64_t audioTimeBaseNum,
     int64_t audioTimeBaseDen) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.video_reader.video_reader.read_video_from_file");
   torch::Tensor dummy_input_video = torch::ones({0});
   return readVideo(
       true,
@@ -653,10 +657,14 @@ torch::List<torch::Tensor> read_video_from_file(
 }

 torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_memory");
   return probeVideo(false, input_video, "");
 }

 torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
+  C10_LOG_API_USAGE_ONCE(
+      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_file");
   torch::Tensor dummy_input_video = torch::ones({0});
   return probeVideo(true, dummy_input_video, videoPath);
 }

torchvision/io/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@

 import torch

+from ..utils import _log_api_usage_once
 from ._video_opt import (
     Timebase,
     VideoMetaData,
@@ -106,6 +107,7 @@ class VideoReader:
     """

     def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> None:
+        _log_api_usage_once(self)
         if not _has_video_opt():
             raise RuntimeError(
                 "Not compiled with video_reader support, "

torchvision/io/image.py

Lines changed: 21 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch

 from ..extension import _load_library
+from ..utils import _log_api_usage_once


 try:
@@ -41,6 +42,8 @@ def read_file(path: str) -> torch.Tensor:
     Returns:
         data (Tensor)
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_file)
     data = torch.ops.image.read_file(path)
     return data

@@ -54,6 +57,8 @@ def write_file(filename: str, data: torch.Tensor) -> None:
         filename (str): the path to the file to be written
         data (Tensor): the contents to be written to the output file
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_file)
     torch.ops.image.write_file(filename, data)


@@ -74,6 +79,8 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_png)
     output = torch.ops.image.decode_png(input, mode.value, False)
     return output

@@ -93,6 +100,8 @@ def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
         Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
             PNG file.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_png)
     output = torch.ops.image.encode_png(input, compression_level)
     return output

@@ -109,6 +118,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
         compression_level (int): Compression factor for the resulting file, it must be a number
             between 0 and 9. Default: 6
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_png)
     output = encode_png(input, compression_level)
     write_file(filename, output)

@@ -137,6 +148,8 @@ def decode_jpeg(
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_jpeg)
     device = torch.device(device)
     if device.type == "cuda":
         output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
@@ -160,6 +173,8 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
         output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
             JPEG file.
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(encode_jpeg)
     if quality < 1 or quality > 100:
         raise ValueError("Image quality should be a positive number between 1 and 100")

@@ -178,6 +193,8 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
         quality (int): Quality of the resulting JPEG file, it must be a number
             between 1 and 100. Default: 75
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_jpeg)
     output = encode_jpeg(input, quality)
     write_file(filename, output)

@@ -201,6 +218,8 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(decode_image)
     output = torch.ops.image.decode_image(input, mode.value)
     return output

@@ -221,6 +240,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_image)
     data = read_file(path)
     return decode_image(data, mode)

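Every public Python entry point above gets the same two-line guard: logging is skipped while TorchScript is scripting or tracing, so the helper call never ends up inside a captured graph. A sketch of the pattern with a stand-in helper (torchvision's real _log_api_usage_once lives in torchvision.utils and may differ in detail):

import torch


def _log_api_usage_once(obj):
    # Forward a qualified name to PyTorch's one-shot usage logger (stand-in for the real helper).
    name = obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")


def my_op(x: torch.Tensor) -> torch.Tensor:
    # Same guard as in torchvision.io: don't log from inside scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(my_op)
    return x * 2


print(my_op(torch.ones(3)))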
torchvision/io/video.py

Lines changed: 7 additions & 0 deletions
@@ -9,6 +9,7 @@
 import numpy as np
 import torch

+from ..utils import _log_api_usage_once
 from . import _video_opt


@@ -77,6 +78,8 @@ def write_video(
         audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
         audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(write_video)
     _check_av_available()
     video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

@@ -256,6 +259,8 @@ def read_video(
         aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
         info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_video)

     from torchvision import get_video_backend

@@ -374,6 +379,8 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
         video_fps (float, optional): the frame rate for the video

     """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(read_video_timestamps)
     from torchvision import get_video_backend

     if get_video_backend() != "pyav":

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 6 additions & 6 deletions
@@ -1,3 +1,4 @@
+import functools
 import io
 import pathlib
 import re
@@ -8,7 +9,6 @@
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    Shuffler,
     Filter,
     IterKeyZipper,
 )
@@ -20,7 +20,7 @@
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Label, BoundingBox, Feature


@@ -121,7 +121,7 @@ def _make_datapipe(

         images_dp = Filter(images_dp, self._is_not_background_image)
         images_dp = hint_sharding(images_dp)
-        images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
+        images_dp = hint_shuffling(images_dp)

         anns_dp = Filter(anns_dp, self._is_ann)

@@ -133,7 +133,7 @@
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -185,8 +185,8 @@ def _make_datapipe(
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
-        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        dp = hint_shuffling(dp)
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)

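The Mapper calls above switch from fn_kwargs to binding the extra keyword with functools.partial up front. A standalone illustration, with plain Python standing in for the torchdata datapipes and decoder:

import functools


def collate_and_decode_sample(data, *, decoder):
    return decoder(data)


samples = ["airplane", "bonsai", "camera"]

# fn_kwargs=dict(decoder=...) forwarded the kwarg at call time; partial bakes it in,
# leaving a plain one-argument callable for the mapping step.
fn = functools.partial(collate_and_decode_sample, decoder=str.upper)
print(list(map(fn, samples)))  # ['AIRPLANE', 'BONSAI', 'CAMERA']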