Skip to content

Commit cf88959

Browse files
gchananPenghuiCheng
authored andcommitted
Move minimal wrapdim functionality to core, remove THTensor include i… (pytorch#11283)
Summary: …n TensorImpl. Pull Request resolved: pytorch#11283 Reviewed By: ezyang Differential Revision: D9660015 Pulled By: gchanan fbshipit-source-id: 263cba226d9ee981d55281c94e6fda5842a46b02
1 parent 04cb5a5 commit cf88959

File tree

4 files changed

+31
-27
lines changed

4 files changed

+31
-27
lines changed

aten/src/ATen/TensorImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
#include <ATen/core/optional.h>
66
#include <ATen/Context.h>
77
#include <ATen/core/Backend.h>
8+
#include <ATen/core/WrapDimMinimal.h>
89

910
#include <ATen/detail/VariableHooksInterface.h>
1011

11-
#include <TH/THTensor.hpp>
12-
1312
namespace at {
1413

1514
Type& TensorImpl::type() const {

aten/src/ATen/WrapDimUtils.h

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,10 @@
11
#pragma once
22

3+
#include "ATen/core/WrapDimMinimal.h"
34
#include "ATen/TensorImpl.h"
4-
#include <sstream>
55

66
namespace at {
77

8-
static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) {
9-
if (dim_post_expr <= 0) {
10-
if (!wrap_scalar) {
11-
std::ostringstream oss;
12-
oss << "dimension specified as " << dim << " but tensor has no dimensions";
13-
throw std::runtime_error(oss.str());
14-
}
15-
dim_post_expr = 1; // this will make range [-1, 0]
16-
}
17-
18-
int64_t min = -dim_post_expr;
19-
int64_t max = dim_post_expr - 1;
20-
AT_CHECK(
21-
dim >= min && dim <= max,
22-
"Dimension out of range (expected to be in range of [",
23-
min, ", ", max, "], but got ", dim, ")");
24-
if (dim < 0) dim += dim_post_expr;
25-
return dim;
26-
}
27-
288
static inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl *tensor) {
299
return maybe_wrap_dim(dim, tensor->dim());
3010
}

aten/src/ATen/core/WrapDimMinimal.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once
2+
3+
#include "ATen/core/Error.h"
4+
5+
namespace at {
6+
7+
static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) {
8+
if (dim_post_expr <= 0) {
9+
AT_CHECK(wrap_scalar, "dimension specified as ", dim, " but tensor has no dimensions");
10+
dim_post_expr = 1; // this will make range [-1, 0]
11+
}
12+
13+
int64_t min = -dim_post_expr;
14+
int64_t max = dim_post_expr - 1;
15+
AT_CHECK(
16+
dim >= min && dim <= max,
17+
"Dimension out of range (expected to be in range of [",
18+
min, ", ", max, "], but got ", dim, ")");
19+
if (dim < 0) dim += dim_post_expr;
20+
return dim;
21+
}
22+
23+
}

aten/src/ATen/test/native_test.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
using namespace at;
88

9+
using Catch::Matchers::StartsWith;
10+
911
#define REQUIRE_EQUAL(t1, t2) \
1012
REQUIRE(t1.equal(t2));
1113

@@ -73,10 +75,10 @@ void test(Type & T, Type & AccT) {
7375

7476
SECTION( "size / stride" ) {
7577
auto scalar = randn({}, T);
76-
REQUIRE_THROWS_WITH(scalar.size(0), "dimension specified as 0 but tensor has no dimensions");
77-
REQUIRE_THROWS_WITH(scalar.size(-1), "dimension specified as -1 but tensor has no dimensions");
78-
REQUIRE_THROWS_WITH(scalar.stride(0), "dimension specified as 0 but tensor has no dimensions");
79-
REQUIRE_THROWS_WITH(scalar.stride(-1), "dimension specified as -1 but tensor has no dimensions");
78+
REQUIRE_THROWS_WITH(scalar.size(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
79+
REQUIRE_THROWS_WITH(scalar.size(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
80+
REQUIRE_THROWS_WITH(scalar.stride(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
81+
REQUIRE_THROWS_WITH(scalar.stride(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
8082

8183
auto empty = randn({0}, T);
8284
REQUIRE(empty.size(0) == 0);

0 commit comments

Comments
 (0)