Commit 1cd11db

Revert "Cleaner semantics for Reserve (pytorch#10261)"
This reverts commit e0d4357.
1 parent 25b2e88 commit 1cd11db

7 files changed (+44 −94 lines)

caffe2/core/tensor.h

Lines changed: 19 additions & 55 deletions

@@ -231,17 +231,6 @@ class Tensor {
 
   virtual ~Tensor() noexcept {}
 
-  /**
-   * @brief Extend the outer-most dimension of this tensor
-   * to dimension of `num`.
-   */
-  void ExtendTo(TIndex num, float growthPct, BaseContext* context) {
-    CAFFE_ENFORCE_GE_WITH_CALLER(dims_.size(), 1);
-    CAFFE_ENFORCE_GE_WITH_CALLER(growthPct, 0);
-    CAFFE_ENFORCE(context != nullptr, "Context must be provided.");
-    Extend(num - dims_[0], growthPct, context);
-  }
-
   /**
    * @brief Extends the outer-most dimension of this tensor by num elements,
    * preserving the existing data.
@@ -253,8 +242,6 @@ class Tensor {
   */
   void Extend(TIndex num, float growthPct, BaseContext* context) {
     CAFFE_ENFORCE_GE_WITH_CALLER(dims_.size(), 1);
-    CAFFE_ENFORCE_GE_WITH_CALLER(
-        num, 0, "`num` must be non-negative for Extend");
     auto newDims = dims_;
     newDims[0] += num;
     if (!data_) {
@@ -274,17 +261,30 @@ class Tensor {
     auto newCapacity = dims_;
     newCapacity[0] = std::max<size_t>(
         newDims[0], std::ceil(dims_[0] * (growthPct + 100) / 100));
+    Reserve(newCapacity, context);
+    dims_ = newDims;
+    size_ = newSize;
+  }
+
+  template <class T>
+  void Reserve(const std::vector<T>& newCapacity, BaseContext* context) {
+    auto newSize = std::accumulate(
+        newCapacity.begin(),
+        newCapacity.end(),
+        static_cast<TIndex>(1),
+        std::multiplies<TIndex>());
+    if (newSize * meta_.itemsize() <= capacity_) {
+      return;
+    }
     auto oldData = std::move(data_);
     auto oldSize = size_;
     auto oldDims = dims_;
     Resize(newCapacity);
     auto* newData = raw_mutable_data(meta_);
-    CAFFE_ENFORCE(
-        context != nullptr, "Context must be provided to Extend the tensor");
     context->CopyItemsSameDevice(meta_, oldSize, oldData.get(), newData);
+    dims_ = oldDims;
+    size_ = oldSize;
     reserved_ = true;
-    dims_ = newDims;
-    size_ = newSize;
   }
 
   /**
@@ -293,7 +293,7 @@ class Tensor {
   * This method guarantees that no re-allocations are carried out, which means
   * that the extra capacity after the end of the shurnk tensor is maintained.
   */
-  void ShrinkTo(TIndex outer_dim) {
+  void Shrink(TIndex outer_dim) {
     CAFFE_ENFORCE_WITH_CALLER(dims_.size() >= 1, "Tensor must be at least 1D");
     CAFFE_ENFORCE_WITH_CALLER(
         outer_dim <= dims_[0],
@@ -306,38 +306,6 @@ class Tensor {
         std::multiplies<TIndex>());
   }
 
-  /**
-   * @brief Reserve space for the underlying tensor.
-   *
-   * This must be called after Resize(), since we only specify the first
-   * dimension This does not copy over the old data to the newly allocated space
-   */
-  template <class T>
-  void ReserveSpace(const T& outer_dim) {
-    CAFFE_ENFORCE(
-        size_ != -1, "size should be initialized before calling ReserveSpace");
-    auto newCapacity = dims_;
-    newCapacity[0] = outer_dim;
-    auto newSize = std::accumulate(
-        newCapacity.begin(),
-        newCapacity.end(),
-        static_cast<TIndex>(1),
-        std::multiplies<TIndex>());
-    if (newSize * meta_.itemsize() <= capacity_) {
-      return;
-    }
-    // Old data is discarded
-    data_.reset();
-    auto oldSize = size_;
-    auto oldDims = dims_;
-    Resize(newCapacity);
-    // Allocate new memory and don't copy over the data
-    raw_mutable_data(meta_);
-    dims_ = oldDims;
-    size_ = oldSize;
-    reserved_ = true;
-  }
-
   /**
   * @brief Resizes a tensor.
   *
@@ -421,7 +389,7 @@ class Tensor {
       capacity_ = 0;
       // If reserved is true and we changed tensor memory then it is fine
      // to switch it to false, if Resize is called from Reserve and it triggers
-      // FreeMemory() then reserved_ will be set to true at end of ReserveSpace()
+      // FreeMemory() then reserved_ will be set to true at end of Reserve()
      reserved_ = false;
    }
 
@@ -772,10 +740,6 @@ class Tensor {
  TypeMeta meta_;
  std::shared_ptr<void> data_;
  size_t capacity_ = 0;
-  // we decide to keep reserved and it will
-  // live in Tensor after the split
-  // The logic is that if Extend() or ReserveSpace() were ever called,
-  // then subsequent Resize()s will not free up Storage.
  bool reserved_ = false;
  DeviceType device_type_ = CPU;
  // In case of chunk load we store how much data was already loaded
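
Taken together, the restored header routes all growth through the templated Reserve(): it enlarges the backing allocation (copying existing items through the context) while leaving dims_ and size_ untouched, Extend() then bumps the logical outer dimension, and Shrink() only narrows it without releasing memory. The following is a minimal usage sketch, not code from the commit; it assumes a CPU-resident caffe2::Tensor, the CPUContext from caffe2/core/context.h, and float elements, and every name other than the methods shown above is hypothetical.

#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"

// Sketch only: exercises the Reserve/Extend/Shrink semantics restored above.
void GrowAndTrim(caffe2::Tensor* rows, caffe2::CPUContext* context) {
  // Start with an empty outer dimension but a fixed row width of 16 floats.
  rows->Resize(0, 16);
  rows->mutable_data<float>();  // fix the element type before reserving

  // Grow only the backing allocation; dims_ stays {0, 16} and size_ stays 0.
  rows->Reserve(std::vector<caffe2::TIndex>{128, 16}, context);

  // Append 32 rows; no reallocation, since 32 x 16 floats fit the reservation.
  rows->Extend(32, /*growthPct=*/50, context);

  // Keep only the first 8 rows; capacity and the reserved_ flag are retained.
  rows->Shrink(8);
}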

caffe2/experiments/operators/tt_pad_op.h

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ class TTPadGradientOp final : public Operator<Context> {
     auto dim1 = G.dim(1);
 
     if (old_dim0 < new_dim0) {
-      output->ShrinkTo(old_dim0);
+      output->Shrink(old_dim0);
     }
 
     return true;

caffe2/mobile/contrib/ulp2/ulp_neon.cc

Lines changed: 1 addition & 1 deletion

@@ -537,7 +537,7 @@ void run2b1bConvIm2ColGEMM(QConvState* state,
   } else {
     CAFFE_ENFORCE_EQ(Y->dim32(0), divRoundUp(X.dim32(0) * OH * OW, kGEMMTileSize) * kGEMMTileSize);
     CAFFE_ENFORCE_EQ(Y->dim32(1), OC);
-    Y->ShrinkTo(X.dim32(0) * OH * OW);
+    Y->Shrink(X.dim32(0) * OH * OW);
     Y->Reshape(std::vector<TIndex>{{TIndex(X.dim(0)), TIndex(OH), TIndex(OW), TIndex(OC)}});
   }
 }

caffe2/operators/dataset_ops.cc

Lines changed: 1 addition & 1 deletion

@@ -1004,7 +1004,7 @@ class TrimDatasetOp : public Operator<CPUContext> {
     // trim each column to the offset
     for (int col = 0; col < walker.fields().size(); ++col) {
       auto newOuterSize = walker.fields().at(col).offset();
-      Output(col)->ShrinkTo(newOuterSize);
+      Output(col)->Shrink(newOuterSize);
     }
     return true;
   }

caffe2/operators/last_n_window_collector.cc

Lines changed: 7 additions & 16 deletions

@@ -46,7 +46,8 @@ class LastNWindowCollectorOp : public Operator<Context> {
       }
     }
 
-    auto num_entries = input.dims()[0];
+    auto dims = input.dims();
+    auto num_entries = dims[0];
 
     if (OutputSize() > NUM_VISITED) {
       auto* num_visited_tensor = Output(NUM_VISITED);
@@ -59,14 +60,8 @@ class LastNWindowCollectorOp : public Operator<Context> {
       *num_visited += num_entries;
     }
 
-    if (!output_initialized) {
-      auto dims = input.dims();
-      dims[0] = 0;
-      output->Resize(dims);
-      // pass meta to output
-      output->raw_mutable_data(input.meta());
-      output->ReserveSpace(numToCollect_);
-    }
+    dims[0] = numToCollect_;
+    output->Reserve(dims, &context_);
 
     if (num_entries == 0) {
       if (!output_initialized) {
@@ -78,14 +73,10 @@ class LastNWindowCollectorOp : public Operator<Context> {
 
     auto num_to_copy = std::min<int32_t>(num_entries, numToCollect_);
     auto output_batch_size = output_initialized ? output->dim(0) : 0;
-    auto output_num =
-        std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
-
-    // output_num is >= output_batch_size
-    if (output_num > output_batch_size) {
-      output->ExtendTo(output_num, 50, &context_);
+    dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
+    if (output_batch_size < numToCollect_) {
+      output->Resize(dims);
     }
-
     auto* output_data =
        static_cast<char*>(output->raw_mutable_data(input.meta()));
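
With ExtendTo() and ReserveSpace() gone, the collector now reserves capacity for the full window through Reserve() (which returns early once the requested bytes already fit in capacity_) and then Resize()s the output to however many rows it actually holds; because the tensor is marked reserved_, that Resize() keeps the existing allocation and the rows copied on earlier runs. Below is a condensed, hypothetical sketch of that flow, not the operator itself; it assumes the output was already initialized with the input's meta on a previous call, so dim(0) and the capacity check see the right element type and shape.

#include <algorithm>

#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"

// Sketch only: the reserve-then-resize flow used by LastNWindowCollectorOp.
void CollectWindow(caffe2::Tensor* output,
                   const caffe2::Tensor& input,
                   caffe2::TIndex numToCollect,
                   caffe2::CPUContext* context) {
  auto dims = input.dims();  // copied, then edited in place
  const caffe2::TIndex num_entries = dims[0];

  // Reserve room for the full window; later calls are no-ops because the
  // requested byte size already fits within capacity_.
  dims[0] = numToCollect;
  output->Reserve(dims, context);

  // Expose only the rows filled so far; the resize stays inside the reserved
  // allocation, so rows collected on previous runs are preserved.
  dims[0] = std::min(numToCollect, output->dim(0) + num_entries);
  output->Resize(dims);
}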

caffe2/operators/reservoir_sampling.cc

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,23 @@ class ReservoirSamplingOp final : public Operator<Context> {
3939
}
4040
}
4141

42-
auto num_entries = input.dims()[0];
42+
auto dims = input.dims();
43+
auto num_entries = dims[0];
4344

44-
if (!output_initialized) {
45-
// IMPORTANT: Force the output to have the right type before reserving,
46-
// so that the output gets the right capacity
47-
auto dims = input.dims();
48-
dims[0] = 0;
49-
output->Resize(dims);
50-
output->raw_mutable_data(input.meta());
51-
output->ReserveSpace(numToCollect_);
52-
}
45+
dims[0] = numToCollect_;
46+
// IMPORTANT: Force the output to have the right type before reserving,
47+
// so that the output gets the right capacity
48+
output->raw_mutable_data(input.meta());
49+
output->Reserve(dims, &context_);
5350

5451
auto* pos_to_object =
5552
OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
5653
if (pos_to_object) {
5754
if (!output_initialized) {
5855
// Cleaning up in case the reservoir got reset.
5956
pos_to_object->Resize(0);
60-
pos_to_object->template mutable_data<int64_t>();
61-
pos_to_object->ReserveSpace(numToCollect_);
6257
}
58+
pos_to_object->Reserve(std::vector<TIndex>{numToCollect_}, &context_);
6359
}
6460

6561
auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
@@ -100,14 +96,13 @@ class ReservoirSamplingOp final : public Operator<Context> {
10096
const auto num_new_entries = countNewEntries(unique_object_ids);
10197
auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
10298
auto output_batch_size = output_initialized ? output->dim(0) : 0;
103-
auto output_num =
104-
std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
105-
// output_num is >= output_batch_size
106-
output->ExtendTo(output_num, 50, &context_);
107-
if (pos_to_object) {
108-
pos_to_object->ExtendTo(output_num, 50, &context_);
99+
dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
100+
if (output_batch_size < numToCollect_) {
101+
output->Resize(dims);
102+
if (pos_to_object) {
103+
pos_to_object->Resize(dims[0]);
104+
}
109105
}
110-
111106
auto* output_data =
112107
static_cast<char*>(output->raw_mutable_data(input.meta()));
113108
auto* pos_to_object_data = pos_to_object
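
The restored IMPORTANT comment is load-bearing: as the header diff above shows, Reserve() compares newSize * meta_.itemsize() against capacity_, so if the output's element type has not been fixed yet the item size is presumably zero and the reservation would silently do nothing. Calling raw_mutable_data(input.meta()) first gives the capacity check a real element size. A minimal, hypothetical illustration of that ordering (the helper name and parameters are assumptions, not part of the commit):

#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"

// Sketch only: why raw_mutable_data(meta) must precede Reserve() above.
void ReserveReservoir(caffe2::Tensor* output,
                      const caffe2::Tensor& input,
                      caffe2::TIndex numToCollect,
                      caffe2::CPUContext* context) {
  auto dims = input.dims();
  dims[0] = numToCollect;

  // Force the output onto the input's type; without this, meta_.itemsize()
  // is 0 and Reserve() would return early having allocated nothing.
  output->raw_mutable_data(input.meta());

  // Now the capacity check uses the real element size, so the backing store
  // is sized for numToCollect rows up front.
  output->Reserve(dims, context);
}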

caffe2/operators/text_file_reader.cc

Lines changed: 1 addition & 1 deletion

@@ -150,7 +150,7 @@ class TextFileReaderReadOp : public Operator<CPUContext> {
     }
 
     for (int i = 0; i < numFields; ++i) {
-      Output(i)->ShrinkTo(rowsRead);
+      Output(i)->Shrink(rowsRead);
     }
     return true;
   }
