Skip to content

Commit cac9107

Browse files
gchananfacebook-github-bot
authored andcommitted
Remove THTensor::_dim, temporarily remove THTensor_nDimension. (#9895)
Summary: Pull Request resolved: pytorch/pytorch#9895 The primary goal here was to remove THTensor::_dim, which isn't part of the API moving forward. Instead, we provide 3 options for getting the dimensionality (this is temporary although non-trivial to remove!): ``` nDimension corresponds to the "true" ATen dimension. TODO: implement. nDimensionLegacyNoScalars correpsonds to the ATen dimension, except scalars are viewed as 1-dimensional tensors. nDimensionLegacyAll corresponds to the ATen dimension, except scalars are viewed as 1-dimensional tensors and tensors with a dimension of size zero are collapsed to 0-dimensional tensors. ``` So in this patch, nDimension -> nDimensionLegacyNoScalars and _dim/_nDimension goes to nDimensionLegacyAll. These are just codemods. Pull Request resolved: pytorch/pytorch#9835 Reviewed By: ezyang Differential Revision: D8999338 Pulled By: gchanan fbshipit-source-id: a4d676ac728f6f36ca09604a41e888d545ae9311
1 parent 46b7df7 commit cac9107

File tree

87 files changed

+442
-424
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+442
-424
lines changed

aten/src/TH/THTensor.hpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,6 @@ struct THTensor
5252
return storage_->unsafe_data<T>() + storage_offset_;
5353
}
5454

55-
// [NOTE: _dim() vs dim()]
56-
// _dim() returns the "old" TH dimension view where no dimensions represents an empty tensor.
57-
// dim() returns the ATen view of the dimensionality, i.e. 0-sized dimensions are supported.
58-
inline int64_t _dim() const {
59-
return is_empty() ? 0 : dim();
60-
}
61-
6255
inline int64_t dim() const {
6356
return sizes_.size();
6457
}
@@ -159,6 +152,31 @@ inline void THTensor_setIsZeroDim(THTensor *tensor, bool is_zero_dim) {
159152
tensor->is_zero_dim_ = is_zero_dim;
160153
}
161154

155+
// [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
156+
// nDimension corresponds to the "true" ATen dimension. TODO: implement.
157+
// nDimensionLegacyNoScalars correpsonds to the ATen dimension, except scalars are viewed as 1-dimensional tensors.
158+
// nDimensionLegacyAll corresponds to the ATen dimension, except scalars are viewed as 1-dimensional tensors
159+
// and tensors with a dimension of size zero are collapsed to 0-dimensional tensors.
160+
//
161+
// Eventually, everything should go through nDimension or tensor->dim().
162+
inline int THTensor_nDimensionLegacyNoScalars(const THTensor* tensor) {
163+
if (THTensor_isZeroDim(tensor)) {
164+
return 1;
165+
} else {
166+
return tensor->dim();
167+
}
168+
}
169+
170+
inline int THTensor_nDimensionLegacyAll(const THTensor* tensor) {
171+
if (tensor->is_empty()) {
172+
return 0;
173+
} else if (THTensor_isZeroDim(tensor)) {
174+
return 1;
175+
} else {
176+
return tensor->dim();
177+
}
178+
}
179+
162180
TH_API void THTensor_free(THTensor *self);
163181
TH_CPP_API at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
164182
at::IntList newshape);

aten/src/TH/THTensorApply.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
TENSOR##_data = THTensor_getStoragePtr(TENSOR)->data<TYPE>()+TENSOR->storage_offset(); \
4747
TENSOR##_size = 1; \
4848
TENSOR##_stride = 1; \
49-
for(TENSOR##_i = TENSOR->_dim()-1; TENSOR##_i >= 0; TENSOR##_i--) { \
49+
for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-1; TENSOR##_i >= 0; TENSOR##_i--) { \
5050
if(TENSOR->size(TENSOR##_i) != 1) { \
5151
if(TENSOR->stride(TENSOR##_i) == TENSOR##_size && TENSOR##_i != DIM) \
5252
TENSOR##_size *= TENSOR->size(TENSOR##_i); \
@@ -59,7 +59,7 @@
5959
if (!TENSOR##_contiguous) { \
6060
/* Find the dimension of contiguous sections */ \
6161
TENSOR##_dim = 1; \
62-
for(TENSOR##_i = TENSOR->_dim()-2; TENSOR##_i >= 0; TENSOR##_i--) \
62+
for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-2; TENSOR##_i >= 0; TENSOR##_i--) \
6363
{ \
6464
if(TENSOR->stride(TENSOR##_i) != TENSOR->stride(TENSOR##_i+1) * TENSOR->size(TENSOR##_i+1) || TENSOR##_i == DIM || TENSOR##_i+1 == DIM) \
6565
TENSOR##_dim++; \
@@ -69,19 +69,19 @@
6969
TENSOR##_sizes = TENSOR##_counter + TENSOR##_dim; \
7070
TENSOR##_strides = TENSOR##_counter + 2*TENSOR##_dim; \
7171
TH_TENSOR_dim_index = TENSOR##_dim-1; \
72-
TENSOR##_dimOffset = (DIM == TENSOR->_dim()-1) ? &TENSOR##_i : &TENSOR##_counter[DIM]; \
73-
TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(TENSOR->_dim()-1); \
74-
TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride(TENSOR->_dim()-1); \
72+
TENSOR##_dimOffset = (DIM == THTensor_nDimensionLegacyAll(TENSOR)-1) ? &TENSOR##_i : &TENSOR##_counter[DIM]; \
73+
TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(THTensor_nDimensionLegacyAll(TENSOR)-1); \
74+
TENSOR##_strides[TH_TENSOR_dim_index] = TENSOR->stride(THTensor_nDimensionLegacyAll(TENSOR)-1); \
7575
/* TENSOR##_counter tracks where we are in the storage. The offset into the */ \
7676
/* storage is given by storage_offset + (i * j), where i is the stride */ \
7777
/* vector and j is tensor_counter vector. This sets the starting position for the loop. */ \
7878
for(TENSOR##_i = TENSOR##_dim-1; TENSOR##_i >= 0; --TENSOR##_i) { \
7979
TENSOR##_counter[TENSOR##_i] = 0; \
8080
} \
81-
for(TENSOR##_i = TENSOR->_dim()-2; TENSOR##_i >= 0; --TENSOR##_i) { \
81+
for(TENSOR##_i = THTensor_nDimensionLegacyAll(TENSOR)-2; TENSOR##_i >= 0; --TENSOR##_i) { \
8282
if (TENSOR->stride(TENSOR##_i) == TENSOR->stride(TENSOR##_i+1) * TENSOR->size(TENSOR##_i+1) && TENSOR##_i != DIM && TENSOR##_i+1 != DIM) { \
8383
TENSOR##_sizes[TH_TENSOR_dim_index] = TENSOR->size(TENSOR##_i) * TENSOR##_sizes[TH_TENSOR_dim_index]; \
84-
if (DIM != TENSOR->_dim()-1 && TENSOR##_i < DIM) \
84+
if (DIM != THTensor_nDimensionLegacyAll(TENSOR)-1 && TENSOR##_i < DIM) \
8585
TENSOR##_dimOffset--; \
8686
} else { \
8787
--TH_TENSOR_dim_index; \

aten/src/TH/THTensorDimApply.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@
146146
int TH_TENSOR_DIM_APPLY_i; \
147147
\
148148
if( (DIMENSION < 0) || (DIMENSION >= TENSOR1->dim()) ) \
149-
THError("invalid dimension %d (expected to be 0 <= dim < %d)", DIMENSION, TENSOR1->_dim()); \
149+
THError("invalid dimension %d (expected to be 0 <= dim < %d)", DIMENSION, THTensor_nDimensionLegacyAll(TENSOR1)); \
150150
if( TENSOR1->dim() != TENSOR2->dim() ) { \
151151
AT_ERROR("inconsistent tensor size, expected ", #TENSOR1, " ", TENSOR1->sizes(), " and ", #TENSOR2, " ", TENSOR2->sizes(), " to have the same number of dimensions"); \
152152
} \
@@ -266,33 +266,33 @@
266266
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
267267
int TH_TENSOR_DIM_APPLY_i; \
268268
\
269-
if( (DIMENSION < 0) || (DIMENSION >= TENSOR->_dim()) ) \
269+
if( (DIMENSION < 0) || (DIMENSION >= THTensor_nDimensionLegacyAll(TENSOR)) ) \
270270
THError("invalid dimension"); \
271271
\
272272
TENSOR##_data = THTensor_getStoragePtr(TENSOR)->data<TYPE>()+(TENSOR)->storage_offset(); \
273273
TENSOR##_stride = (TENSOR)->stride(DIMENSION); \
274274
TENSOR##_size = TENSOR->size(DIMENSION); \
275275
/* Counter stores the indices into the Tensor at any time */ \
276-
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR->_dim())); \
277-
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->_dim(); TH_TENSOR_DIM_APPLY_i++) \
276+
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(THTensor_nDimensionLegacyAll(TENSOR))); \
277+
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR); TH_TENSOR_DIM_APPLY_i++) \
278278
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
279279
\
280280
while(!TH_TENSOR_DIM_APPLY_hasFinished) \
281281
{ \
282282
CODE \
283283
\
284-
if(TENSOR->_dim() == 1) \
284+
if(THTensor_nDimensionLegacyAll(TENSOR) == 1) \
285285
break; \
286286
\
287-
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR->_dim(); TH_TENSOR_DIM_APPLY_i++) \
287+
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR); TH_TENSOR_DIM_APPLY_i++) \
288288
{ \
289289
/* Check if the index is equal to DIMENSION. We don't need to update the */ \
290290
/* offset if this is the case, and can consider the next index. However, */ \
291291
/* in the case that the DIMENSION is the last index in the Tensor, then */ \
292292
/* we have parsed the entire tensor and can exit */ \
293293
if(TH_TENSOR_DIM_APPLY_i == DIMENSION) \
294294
{ \
295-
if(TH_TENSOR_DIM_APPLY_i == TENSOR->_dim()-1) \
295+
if(TH_TENSOR_DIM_APPLY_i == THTensor_nDimensionLegacyAll(TENSOR)-1) \
296296
{ \
297297
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
298298
break; \
@@ -307,7 +307,7 @@
307307
if(TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR->size(TH_TENSOR_DIM_APPLY_i)) \
308308
{ \
309309
/* Handled TENSOR_size(dim) iterations for DIM_APPLY_i. If this is the last dimension, exit */ \
310-
if(TH_TENSOR_DIM_APPLY_i == TENSOR->_dim()-1) \
310+
if(TH_TENSOR_DIM_APPLY_i == THTensor_nDimensionLegacyAll(TENSOR)-1) \
311311
{ \
312312
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
313313
break; \

aten/src/TH/generic/THTensor.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,27 @@ ptrdiff_t THTensor_(storageOffset)(const THTensor *self)
1515
return self->storage_offset();
1616
}
1717

18-
int THTensor_(nDimension)(const THTensor *self)
18+
int THTensor_(nDimensionLegacyNoScalars)(const THTensor *self)
1919
{
20-
return self->dim();
20+
return THTensor_nDimensionLegacyNoScalars(self);
2121
}
2222

23-
int THTensor_(_nDimension)(const THTensor *self)
23+
int THTensor_(nDimensionLegacyAll)(const THTensor *self)
2424
{
25-
return self->_dim();
25+
return THTensor_nDimensionLegacyAll(self);
2626
}
2727

2828
int64_t THTensor_(size)(const THTensor *self, int dim)
2929
{
3030
THArgCheck((dim >= 0) && (dim < self->dim()), 2, "dimension %d out of range of %dD tensor",
31-
dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
31+
dim+TH_INDEX_BASE, THTensor_(nDimensionLegacyNoScalars)(self));
3232
return self->size(dim);
3333
}
3434

3535
int64_t THTensor_(stride)(const THTensor *self, int dim)
3636
{
3737
THArgCheck((dim >= 0) && (dim < self->dim()), 2, "dimension %d out of range of %dD tensor",
38-
dim+TH_INDEX_BASE, THTensor_(nDimension)(self));
38+
dim+TH_INDEX_BASE, THTensor_(nDimensionLegacyNoScalars)(self));
3939
return self->stride(dim);
4040
}
4141

@@ -397,7 +397,7 @@ void THTensor_(select)(THTensor *self, THTensor *src, int dimension, int64_t sli
397397
src = self;
398398

399399
#ifndef USE_TH_SIZE_ZERO_DIM
400-
THArgCheck(src->_dim() > 1, 1, "cannot select on a vector");
400+
THArgCheck(THTensor_nDimensionLegacyAll(src) > 1, 1, "cannot select on a vector");
401401
#else
402402
#ifndef USE_TH_SCALAR
403403
THArgCheck(src->dim() > 1, 1, "cannot select on a vector");
@@ -575,7 +575,7 @@ int THTensor_(isTransposed)(const THTensor *self)
575575
int64_t size_max_stride = 1;
576576
int64_t z = 1;
577577
int d;
578-
for (d = 0; d < self->_dim(); ++d) {
578+
for (d = 0; d < THTensor_nDimensionLegacyAll(self); ++d) {
579579
if (self->stride(d) == 0 && self->size(d) != 1)
580580
return 0;
581581
if (self->stride(d) > max_stride) {
@@ -611,10 +611,10 @@ int THTensor_(isContiguous)(const THTensor *self)
611611
int THTensor_(isSize)(const THTensor *self, const THLongStorage *dims)
612612
{
613613
int d;
614-
if (self->_dim() != dims->size)
614+
if (THTensor_nDimensionLegacyAll(self) != dims->size)
615615
return 0;
616616

617-
for(d = 0; d < self->_dim(); ++d)
617+
for(d = 0; d < THTensor_nDimensionLegacyAll(self); ++d)
618618
{
619619
if(self->size(d) != THLongStorage_data(dims)[d])
620620
return 0;
@@ -641,10 +641,10 @@ int THTensor_(isSetTo)(const THTensor *self, const THTensor* src)
641641
return 0;
642642
if (THTensor_getStoragePtr(self) == THTensor_getStoragePtr(src) &&
643643
self->storage_offset() == src->storage_offset() &&
644-
self->_dim() == src->_dim())
644+
THTensor_nDimensionLegacyAll(self) == THTensor_nDimensionLegacyAll(src))
645645
{
646646
int d;
647-
for (d = 0; d < self->_dim(); ++d)
647+
for (d = 0; d < THTensor_nDimensionLegacyAll(self); ++d)
648648
{
649649
if (self->size(d) != src->size(d) || self->stride(d) != src->stride(d))
650650
return 0;
@@ -656,13 +656,13 @@ int THTensor_(isSetTo)(const THTensor *self, const THTensor* src)
656656

657657
ptrdiff_t THTensor_(nElement)(const THTensor *self)
658658
{
659-
if(self->_dim() == 0)
659+
if(THTensor_nDimensionLegacyAll(self) == 0)
660660
return 0;
661661
else
662662
{
663663
ptrdiff_t nElement = 1;
664664
int d;
665-
for(d = 0; d < self->_dim(); d++)
665+
for(d = 0; d < THTensor_nDimensionLegacyAll(self); d++)
666666
nElement *= self->size(d);
667667
return nElement;
668668
}
@@ -790,56 +790,56 @@ void THTensor_(resizeNd)(THTensor *self, int nDimension, int64_t *size, int64_t
790790

791791
void THTensor_(set1d)(THTensor *tensor, int64_t x0, real value)
792792
{
793-
THArgCheck(tensor->_dim() == 1, 1, "tensor must have one dimension");
793+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 1, 1, "tensor must have one dimension");
794794
THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)), 2, "out of range");
795795
THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0), value);
796796
}
797797

798798
real THTensor_(get1d)(const THTensor *tensor, int64_t x0)
799799
{
800-
THArgCheck(tensor->_dim() == 1, 1, "tensor must have one dimension");
800+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 1, 1, "tensor must have one dimension");
801801
THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)), 2, "out of range");
802802
return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0));
803803
}
804804

805805
void THTensor_(set2d)(THTensor *tensor, int64_t x0, int64_t x1, real value)
806806
{
807-
THArgCheck(tensor->_dim() == 2, 1, "tensor must have two dimensions");
807+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 2, 1, "tensor must have two dimensions");
808808
THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)), 2, "out of range");
809809
THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1), value);
810810
}
811811

812812
real THTensor_(get2d)(const THTensor *tensor, int64_t x0, int64_t x1)
813813
{
814-
THArgCheck(tensor->_dim() == 2, 1, "tensor must have two dimensions");
814+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 2, 1, "tensor must have two dimensions");
815815
THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)), 2, "out of range");
816816
return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1));
817817
}
818818

819819
void THTensor_(set3d)(THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, real value)
820820
{
821-
THArgCheck(tensor->_dim() == 3, 1, "tensor must have three dimensions");
821+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 3, 1, "tensor must have three dimensions");
822822
THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)), 2, "out of range");
823823
THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2), value);
824824
}
825825

826826
real THTensor_(get3d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_t x2)
827827
{
828-
THArgCheck(tensor->_dim() == 3, 1, "tensor must have three dimensions");
828+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 3, 1, "tensor must have three dimensions");
829829
THArgCheck( (x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)), 2, "out of range");
830830
return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2));
831831
}
832832

833833
void THTensor_(set4d)(THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3, real value)
834834
{
835-
THArgCheck(tensor->_dim() == 4, 1, "tensor must have four dimensions");
835+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 4, 1, "tensor must have four dimensions");
836836
THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)) && (x3 >= 0) && (x3 < tensor->size(3)), 2, "out of range");
837837
THStorage_(set)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3), value);
838838
}
839839

840840
real THTensor_(get4d)(const THTensor *tensor, int64_t x0, int64_t x1, int64_t x2, int64_t x3)
841841
{
842-
THArgCheck(tensor->_dim() == 4, 1, "tensor must have four dimensions");
842+
THArgCheck(THTensor_nDimensionLegacyAll(tensor) == 4, 1, "tensor must have four dimensions");
843843
THArgCheck((x0 >= 0) && (x0 < tensor->size(0)) && (x1 >= 0) && (x1 < tensor->size(1)) && (x2 >= 0) && (x2 < tensor->size(2)) && (x3 >= 0) && (x3 < tensor->size(3)), 2, "out of range");
844844
return THStorage_(get)(THTensor_getStoragePtr(tensor), tensor->storage_offset()+x0*tensor->stride(0)+x1*tensor->stride(1)+x2*tensor->stride(2)+x3*tensor->stride(3));
845845
}
@@ -853,10 +853,10 @@ THDescBuff THTensor_(desc)(const THTensor *tensor) {
853853
n += snprintf(str, L-n, "torch." _stringify(x) "Tensor of size ");
854854
#undef _stringify
855855
int i;
856-
for(i = 0; i < tensor->_dim(); i++) {
856+
for(i = 0; i < THTensor_nDimensionLegacyAll(tensor); i++) {
857857
if(n >= L) break;
858858
n += snprintf(str+n, L-n, "%" PRId64, tensor->size(i));
859-
if(i < tensor->_dim()-1) {
859+
if(i < THTensor_nDimensionLegacyAll(tensor)-1) {
860860
n += snprintf(str+n, L-n, "x");
861861
}
862862
}

aten/src/TH/generic/THTensor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ typedef struct THTensor THTensor;
2424
TH_API THStorage* THTensor_(storage)(const THTensor *self);
2525
TH_API ptrdiff_t THTensor_(storageOffset)(const THTensor *self);
2626

27-
// See [NOTE: _dim() vs dim()]; _nDimension corresponds to _dim(), nDimension corresponds to dim().
28-
TH_API int THTensor_(nDimension)(const THTensor *self);
29-
TH_API int THTensor_(_nDimension)(const THTensor *self);
27+
// See [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
28+
TH_API int THTensor_(nDimensionLegacyNoScalars)(const THTensor *self);
29+
TH_API int THTensor_(nDimensionLegacyAll)(const THTensor *self);
3030
TH_API int64_t THTensor_(size)(const THTensor *self, int dim);
3131
TH_API int64_t THTensor_(stride)(const THTensor *self, int dim);
3232
TH_API THLongStorage *THTensor_(newSizeOf)(THTensor *self);

aten/src/TH/generic/THTensorApply.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
#define TH_TENSOR_DIM_APPLY3_SIZE_SCATTER(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \
120120
{ \
121121
int shape_check_flag = 0; \
122-
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->_dim(); TH_TENSOR_DIM_APPLY_i++) \
122+
for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR1); TH_TENSOR_DIM_APPLY_i++) \
123123
{ \
124124
int64_t TENSOR3##_dim_size = TENSOR3->size(TH_TENSOR_DIM_APPLY_i); \
125125
if (TH_TENSOR_DIM_APPLY_i != DIMENSION) { \

aten/src/TH/generic/THTensorConv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,7 +1367,7 @@ void THTensor_(conv2Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
13671367

13681368
AT_CHECK(!t_->is_empty() && t_->dim() == 3, "input: non-empty 3D Tensor expected, got size: ", t_->sizes());
13691369
AT_CHECK(!k_->is_empty() && k_->dim() == 3, "kernel: non-empty 3D Tensor expected, got size: ", k_->sizes());
1370-
THArgCheck(map->_dim() == 2 , 4, "map: 2D Tensor expected");
1370+
THArgCheck(THTensor_nDimensionLegacyAll(map) == 2 , 4, "map: 2D Tensor expected");
13711371
THArgCheck(srow >= 1, 6, "Stride should be a positive integer");
13721372
THArgCheck(scol >= 1, 7, "Stride should be a positive integer");
13731373

@@ -1880,7 +1880,7 @@ void THTensor_(conv3Dmap)(THTensor *r_, real beta, real alpha, THTensor *t_, THT
18801880

18811881
AT_CHECK(!t_->is_empty() && t_->dim() == 4, "input: non-empty 4D Tensor expected, got size: ", t_->sizes());
18821882
AT_CHECK(!k_->is_empty() && k_->dim() == 4, "kernel: non-empty 4D Tensor expected, got size: ", k_->sizes());
1883-
THArgCheck(map->_dim() == 2 , 4, "map: 2D Tensor expected");
1883+
THArgCheck(THTensor_nDimensionLegacyAll(map) == 2 , 4, "map: 2D Tensor expected");
18841884
THArgCheck(srow >= 1, 6, "Stride should be a positive integer");
18851885
THArgCheck(scol >= 1, 7, "Stride should be a positive integer");
18861886
THArgCheck(*vf == 'V' || *vf == 'F', 8, "type of convolution can 'V' or 'F'");

aten/src/TH/generic/THTensorCopy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
1717
const int MIN_SZ = 60 * 60;
1818
return THTensor_(isContiguous)(tensor) &&
1919
!src->is_empty() &&
20-
THTensor_(nDimension)(src) == 2 &&
20+
THTensor_(nDimensionLegacyNoScalars)(src) == 2 &&
2121
THTensor_(stride)(src, 0) == 1 &&
2222
THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
2323
THTensor_(nElement)(tensor) >= MIN_SZ;

0 commit comments

Comments
 (0)