Skip to content

Commit 2573b16

Browse files
authored
Merge branch 'main' into add-sun397-datapipe
2 parents ddb1887 + 93104c1 commit 2573b16

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+291
-238
lines changed

.git-blame-ignore-revs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# This file keeps git blame clean.
2+
# See https://docs.github.com/en/repositories/working-with-files/using-files/viewing-a-file#ignore-commits-in-the-blame-view
3+
4+
# Add ufmt (usort + black) as code formatter (#4384)
5+
5f0edb97b46e5bff71dc19dedef05c5396eeaea2
6+
# update python syntax >=3.6 (#4585)
7+
d367a01a18a3ae6bee13d8be3b63fd6a581ea46f

test/builtin_dataset_mocks.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pathlib
1111
import pickle
1212
import random
13+
import warnings
1314
import xml.etree.ElementTree as ET
1415
from collections import defaultdict, Counter
1516

@@ -470,7 +471,10 @@ def imagenet(info, root, config):
470471
]
471472
num_children = 1
472473
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
473-
savemat(data_root / "meta.mat", dict(synsets=synsets))
474+
with warnings.catch_warnings():
475+
# The warning is not for savemat, but rather for some internals savemat is using
476+
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
477+
savemat(data_root / "meta.mat", dict(synsets=synsets))
474478

475479
make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
476480
else: # config.split == "test"

test/test_functional_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class TestRotate:
6767
IMG_W = 26
6868

6969
@pytest.mark.parametrize("device", cpu_and_gpu())
70-
@pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)])
70+
@pytest.mark.parametrize("height, width", [(7, 33), (26, IMG_W), (32, IMG_W)])
7171
@pytest.mark.parametrize(
7272
"center",
7373
[
@@ -77,7 +77,7 @@ class TestRotate:
7777
],
7878
)
7979
@pytest.mark.parametrize("dt", ALL_DTYPES)
80-
@pytest.mark.parametrize("angle", range(-180, 180, 17))
80+
@pytest.mark.parametrize("angle", range(-180, 180, 34))
8181
@pytest.mark.parametrize("expand", [True, False])
8282
@pytest.mark.parametrize(
8383
"fill",

test/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -745,24 +745,24 @@ def test_detection_model_validation(model_fn):
745745
x = [torch.rand(input_shape)]
746746

747747
# validate that targets are present in training
748-
with pytest.raises(ValueError):
748+
with pytest.raises(AssertionError):
749749
model(x)
750750

751751
# validate type
752752
targets = [{"boxes": 0.0}]
753-
with pytest.raises(TypeError):
753+
with pytest.raises(AssertionError):
754754
model(x, targets=targets)
755755

756756
# validate boxes shape
757757
for boxes in (torch.rand((4,)), torch.rand((1, 5))):
758758
targets = [{"boxes": boxes}]
759-
with pytest.raises(ValueError):
759+
with pytest.raises(AssertionError):
760760
model(x, targets=targets)
761761

762762
# validate that no degenerate boxes are present
763763
boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]])
764764
targets = [{"boxes": boxes}]
765-
with pytest.raises(ValueError):
765+
with pytest.raises(AssertionError):
766766
model(x, targets=targets)
767767

768768

test/test_models_detection_anchor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_incorrect_anchors(self):
1616
image1 = torch.randn(3, 800, 800)
1717
image_list = ImageList(image1, [(800, 800)])
1818
feature_maps = [torch.randn(1, 50)]
19-
pytest.raises(ValueError, anc, image_list, feature_maps)
19+
pytest.raises(AssertionError, anc, image_list, feature_maps)
2020

2121
def _init_test_anchor_generator(self):
2222
anchor_sizes = ((10,),)

test/test_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ def test_autocast(self, x_dtype, rois_dtype):
138138

139139
def _helper_boxes_shape(self, func):
140140
# test boxes as Tensor[N, 5]
141-
with pytest.raises(ValueError):
141+
with pytest.raises(AssertionError):
142142
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
143143
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
144144
func(a, boxes, output_size=(2, 2))
145145

146146
# test boxes as List[Tensor[N, 4]]
147-
with pytest.raises(ValueError):
147+
with pytest.raises(AssertionError):
148148
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
149149
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
150150
ops.roi_pool(a, [boxes], output_size=(2, 2))

test/test_prototype_builtin_datasets.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@
77
import torch
88
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
99
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
10-
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
1110
from torch.utils.data.graph import traverse
12-
from torchdata.datapipes.iter import IterDataPipe, Shuffler
11+
from torch.utils.data.graph_settings import get_all_graph_pipes
12+
from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
1313
from torchvision._utils import sequence_to_str
1414
from torchvision.prototype import transforms, datasets
15+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
1516
from torchvision.prototype.features import Image, Label
1617

1718
assert_samples_equal = functools.partial(
1819
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
1920
)
2021

2122

23+
def extract_datapipes(dp):
24+
return get_all_graph_pipes(traverse(dp, only_datapipe=True))
25+
26+
2227
@pytest.fixture
2328
def test_home(mocker, tmp_path):
2429
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -35,6 +40,7 @@ def test_coverage():
3540
)
3641

3742

43+
@pytest.mark.filterwarnings("error")
3844
class TestCommon:
3945
@parametrize_dataset_mocks(DATASET_MOCKS)
4046
def test_smoke(self, test_home, dataset_mock, config):
@@ -118,19 +124,18 @@ def test_serializable(self, test_home, dataset_mock, config):
118124

119125
pickle.dumps(dataset)
120126

127+
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
128+
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
129+
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
121130
@parametrize_dataset_mocks(DATASET_MOCKS)
122131
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
123132
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
124-
def scan(graph):
125-
for node, sub_graph in graph.items():
126-
yield node
127-
yield from scan(sub_graph)
128133

129134
dataset_mock.prepare(test_home, config)
130135

131136
dataset = datasets.load(dataset_mock.name, **config)
132137

133-
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
138+
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
134139
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
135140

136141
@parametrize_dataset_mocks(DATASET_MOCKS)
@@ -144,6 +149,17 @@ def test_save_load(self, test_home, dataset_mock, config):
144149
buffer.seek(0)
145150
assert_samples_equal(torch.load(buffer), sample)
146151

152+
@parametrize_dataset_mocks(DATASET_MOCKS)
153+
def test_infinite_buffer_size(self, test_home, dataset_mock, config):
154+
dataset_mock.prepare(test_home, config)
155+
dataset = datasets.load(dataset_mock.name, **config)
156+
157+
for dp in extract_datapipes(dataset):
158+
if hasattr(dp, "buffer_size"):
159+
# TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is
160+
# resolved
161+
assert dp.buffer_size == INFINITE_BUFFER_SIZE
162+
147163

148164
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
149165
class TestQMNIST:

torchvision/csrc/io/decoder/decoder.cpp

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,6 @@ constexpr size_t kIoBufferSize = 96 * 1024;
1818
constexpr size_t kIoPaddingSize = AV_INPUT_BUFFER_PADDING_SIZE;
1919
constexpr size_t kLogBufferSize = 1024;
2020

21-
int ffmpeg_lock(void** mutex, enum AVLockOp op) {
22-
std::mutex** handle = (std::mutex**)mutex;
23-
switch (op) {
24-
case AV_LOCK_CREATE:
25-
*handle = new std::mutex();
26-
break;
27-
case AV_LOCK_OBTAIN:
28-
(*handle)->lock();
29-
break;
30-
case AV_LOCK_RELEASE:
31-
(*handle)->unlock();
32-
break;
33-
case AV_LOCK_DESTROY:
34-
delete *handle;
35-
break;
36-
}
37-
return 0;
38-
}
39-
4021
bool mapFfmpegType(AVMediaType media, MediaType* type) {
4122
switch (media) {
4223
case AVMEDIA_TYPE_AUDIO:
@@ -202,8 +183,6 @@ void Decoder::initOnce() {
202183
avcodec_register_all();
203184
#endif
204185
avformat_network_init();
205-
// register ffmpeg lock manager
206-
av_lockmgr_register(&ffmpeg_lock);
207186
av_log_set_callback(Decoder::logFunction);
208187
av_log_set_level(AV_LOG_ERROR);
209188
VLOG(1) << "Registered ffmpeg libs";
@@ -277,7 +256,7 @@ bool Decoder::init(
277256
break;
278257
}
279258

280-
fmt = av_find_input_format(fmtName);
259+
fmt = (AVInputFormat*)av_find_input_format(fmtName);
281260
}
282261

283262
const size_t avioCtxBufferSize = kIoBufferSize;
@@ -495,8 +474,8 @@ void Decoder::cleanUp() {
495474

496475
// function does actual work, derived class calls it in working thread
497476
// periodically. On success method returns 0, ENODATA on EOF, ETIMEDOUT if
498-
// no frames got decoded in the specified timeout time, and error on
499-
// unrecoverable error.
477+
// no frames got decoded in the specified timeout time, AVERROR_BUFFER_TOO_SMALL
478+
// when unable to allocate the packet, and an error on unrecoverable error.
500479
int Decoder::getFrame(size_t workingTimeInMs) {
501480
if (inRange_.none()) {
502481
return ENODATA;
@@ -505,10 +484,15 @@ int Decoder::getFrame(size_t workingTimeInMs) {
505484
// once decode() method gets called and grab some bytes
506485
// run this method again
507486
// init package
508-
AVPacket avPacket;
509-
av_init_packet(&avPacket);
510-
avPacket.data = nullptr;
511-
avPacket.size = 0;
487+
// update 03/22: moving memory management to ffmpeg
488+
AVPacket* avPacket;
489+
avPacket = av_packet_alloc();
490+
if (avPacket == nullptr) {
491+
LOG(ERROR) << "decoder as not able to allocate the packet.";
492+
return AVERROR_BUFFER_TOO_SMALL;
493+
}
494+
avPacket->data = nullptr;
495+
avPacket->size = 0;
512496

513497
auto end = std::chrono::steady_clock::now() +
514498
std::chrono::milliseconds(workingTimeInMs);
@@ -520,8 +504,12 @@ int Decoder::getFrame(size_t workingTimeInMs) {
520504
int result = 0;
521505
size_t decodingErrors = 0;
522506
bool decodedFrame = false;
523-
while (!interrupted_ && inRange_.any() && !decodedFrame && watcher()) {
524-
result = av_read_frame(inputCtx_, &avPacket);
507+
while (!interrupted_ && inRange_.any() && !decodedFrame) {
508+
if (watcher() == false) {
509+
result = ETIMEDOUT;
510+
break;
511+
}
512+
result = av_read_frame(inputCtx_, avPacket);
525513
if (result == AVERROR(EAGAIN)) {
526514
VLOG(4) << "Decoder is busy...";
527515
std::this_thread::yield();
@@ -538,10 +526,11 @@ int Decoder::getFrame(size_t workingTimeInMs) {
538526
break;
539527
}
540528

541-
// get stream
542-
auto stream = findByIndex(avPacket.stream_index);
529+
// get stream; if stream cannot be found reset the packet to
530+
// default settings
531+
auto stream = findByIndex(avPacket->stream_index);
543532
if (stream == nullptr || !inRange_.test(stream->getIndex())) {
544-
av_packet_unref(&avPacket);
533+
av_packet_unref(avPacket);
545534
continue;
546535
}
547536

@@ -553,7 +542,7 @@ int Decoder::getFrame(size_t workingTimeInMs) {
553542
bool hasMsg = false;
554543
// packet either got consumed completely or not at all
555544
if ((result = processPacket(
556-
stream, &avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
545+
stream, avPacket, &gotFrame, &hasMsg, params_.fastSeek)) < 0) {
557546
LOG(ERROR) << "processPacket failed with code: " << result;
558547
break;
559548
}
@@ -585,20 +574,18 @@ int Decoder::getFrame(size_t workingTimeInMs) {
585574

586575
result = 0;
587576

588-
av_packet_unref(&avPacket);
577+
av_packet_unref(avPacket);
589578
}
590579

591-
av_packet_unref(&avPacket);
592-
580+
av_packet_free(&avPacket);
593581
VLOG(2) << "Interrupted loop"
594582
<< ", interrupted_ " << interrupted_ << ", inRange_.any() "
595583
<< inRange_.any() << ", decodedFrame " << decodedFrame << ", result "
596584
<< result;
597585

598586
// loop can be terminated, either by:
599587
// 1. explicitly interrupted
600-
// 2. terminated by workable timeout
601-
// 3. unrecoverable error or ENODATA (end of stream)
588+
// 3. unrecoverable error or ENODATA (end of stream) or ETIMEDOUT (timeout)
602589
// 4. decoded frames pts are out of the specified range
603590
// 5. success decoded frame
604591
if (interrupted_) {

torchvision/csrc/io/decoder/stream.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Stream::~Stream() {
2828

2929
// look up the proper CODEC querying the function
3030
AVCodec* Stream::findCodec(AVCodecParameters* params) {
31-
return avcodec_find_decoder(params->codec_id);
31+
return (AVCodec*)avcodec_find_decoder(params->codec_id);
3232
}
3333

3434
// Allocate memory for the AVCodecContext, which will hold the context for

torchvision/csrc/io/decoder/subtitle_stream.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,34 @@ int SubtitleStream::initFormat() {
4343
int SubtitleStream::analyzePacket(const AVPacket* packet, bool* gotFrame) {
4444
// clean-up
4545
releaseSubtitle();
46+
47+
// FIXME: should this even be created?
48+
AVPacket* avPacket;
49+
avPacket = av_packet_alloc();
50+
if (avPacket == nullptr) {
51+
LOG(ERROR)
52+
<< "decoder as not able to allocate the subtitle-specific packet.";
53+
// alternative to ENOMEM
54+
return AVERROR_BUFFER_TOO_SMALL;
55+
}
56+
avPacket->data = nullptr;
57+
avPacket->size = 0;
4658
// check flush packet
47-
AVPacket avPacket;
48-
av_init_packet(&avPacket);
49-
avPacket.data = nullptr;
50-
avPacket.size = 0;
51-
auto pkt = packet ? *packet : avPacket;
59+
auto pkt = packet ? packet : avPacket;
60+
5261
int gotFramePtr = 0;
53-
int result = avcodec_decode_subtitle2(codecCtx_, &sub_, &gotFramePtr, &pkt);
62+
// is there a better way than casting away const?
63+
int result =
64+
avcodec_decode_subtitle2(codecCtx_, &sub_, &gotFramePtr, (AVPacket*)pkt);
5465

5566
if (result < 0) {
5667
LOG(ERROR) << "avcodec_decode_subtitle2 failed, err: "
5768
<< Util::generateErrorDesc(result);
69+
// free the packet we've created
70+
av_packet_free(&avPacket);
5871
return result;
5972
} else if (result == 0) {
60-
result = pkt.size; // discard the rest of the package
73+
result = pkt->size; // discard the rest of the package
6174
}
6275

6376
sub_.release = gotFramePtr;
@@ -66,9 +79,10 @@ int SubtitleStream::analyzePacket(const AVPacket* packet, bool* gotFrame) {
6679
// set proper pts in us
6780
if (gotFramePtr) {
6881
sub_.pts = av_rescale_q(
69-
pkt.pts, inputCtx_->streams[format_.stream]->time_base, timeBaseQ);
82+
pkt->pts, inputCtx_->streams[format_.stream]->time_base, timeBaseQ);
7083
}
7184

85+
av_packet_free(&avPacket);
7286
return result;
7387
}
7488

0 commit comments

Comments
 (0)