@@ -20,32 +20,24 @@ std::ostream& operator<<(std::ostream & out, TensorGeometryArg t) {
20
20
}
21
21
22
22
void checkDim (CheckedFrom c, const TensorGeometryArg& t, int64_t dim) {
23
- if (t->dim () != dim) {
24
- std::ostringstream oss;
25
- oss << " Expected " << dim << " -dimensional tensor, but got "
26
- << t->dim () << " -dimensional tensor for " << t
27
- << " (while checking arguments for " << c << " )" ;
28
- throw std::runtime_error (oss.str ());
29
- }
23
+ AT_CHECK (t->dim () == dim,
24
+ " Expected " , dim, " -dimensional tensor, but got " , t->dim (),
25
+ " -dimensional tensor for " , t," (while checking arguments for " , c, " )" );
30
26
}
31
27
32
28
void checkDimRange (CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end) {
33
- if (t->dim () < dim_start || t->dim () >= dim_end) {
34
- std::ostringstream oss;
35
- oss << " Expected " << dim_start << " to " << (dim_end - 1 ) << " dimensions, but got "
36
- << t->dim () << " -dimensional tensor for " << t
37
- << " (while checking arguments for " << c << " )" ;
38
- throw std::runtime_error (oss.str ());
39
- }
29
+ AT_CHECK (
30
+ t->dim () >= dim_start && t->dim () < dim_end,
31
+ " Expected " , dim_start, " to " , (dim_end - 1 ), " dimensions, but got " ,
32
+ t->dim (), " -dimensional tensor for " , t, " (while checking arguments for " ,
33
+ c, " )" );
40
34
}
41
35
42
36
void checkContiguous (CheckedFrom c, const TensorGeometryArg& t) {
43
- if (!t->is_contiguous ()) {
44
- std::ostringstream oss;
45
- oss << " Expected contiguous tensor, but got non-contiguous tensor for " << t
46
- << " (while checking arguments for " << c << " )" ;
47
- throw std::runtime_error (oss.str ());
48
- }
37
+ AT_CHECK (
38
+ t->is_contiguous (),
39
+ " Expected contiguous tensor, but got non-contiguous tensor for " , t,
40
+ " (while checking arguments for " , c, " )" );
49
41
}
50
42
51
43
void checkAllContiguous (CheckedFrom c, at::ArrayRef<TensorArg> ts) {
@@ -57,23 +49,18 @@ void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
57
49
58
50
void checkSize (CheckedFrom c, const TensorGeometryArg& t, IntList sizes) {
59
51
checkDim (c, t, sizes.size ());
60
- if (!t->sizes ().equals (sizes)) {
61
- std::ostringstream oss;
62
- oss << " Expected tensor of size " << sizes << " , but got tensor of size "
63
- << t->sizes () << " for " << t
64
- << " (while checking arguments for " << c << " )" ;
65
- throw std::runtime_error (oss.str ());
66
- }
52
+ AT_CHECK (
53
+ t->sizes ().equals (sizes),
54
+ " Expected tensor of size " , sizes, " , but got tensor of size " , t->sizes (),
55
+ " for " , t, " (while checking arguments for " , c, " )" );
67
56
}
68
57
69
58
void checkSize (CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size) {
70
- if (t->size (dim) != size) {
71
- std::ostringstream oss;
72
- oss << " Expected tensor to have size " << size << " at dimension " << dim
73
- << " , but got size " << t->size (dim) << " for " << t
74
- << " (while checking arguments for " << c << " )" ;
75
- throw std::runtime_error (oss.str ());
76
- }
59
+ AT_CHECK (
60
+ t->size (dim) == size,
61
+ " Expected tensor to have size " , size, " at dimension " , dim,
62
+ " , but got size " , t->size (dim), " for " , t,
63
+ " (while checking arguments for " , c, " )" );
77
64
}
78
65
79
66
void checkAllSame (CheckedFrom c, ArrayRef<TensorArg> tensors, void (*fn)(CheckedFrom, const TensorArg&, const TensorArg&)) {
@@ -89,37 +76,32 @@ void checkAllSame(CheckedFrom c, ArrayRef<TensorArg> tensors, void(*fn)(CheckedF
89
76
}
90
77
91
78
void checkSameSize (CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
92
- if (!t1->sizes ().equals (t2->sizes ())) {
93
- std::ostringstream oss;
94
- oss << " Expected tensor for " << t1 << " to have same size as tensor for "
95
- << t2 << " ; but " << t1->sizes () << " does not equal " << t2->sizes ()
96
- << " (while checking arguments for " << c << " )" ;
97
- throw std::runtime_error (oss.str ());
98
- }
79
+ AT_CHECK (
80
+ t1->sizes ().equals (t2->sizes ()),
81
+ " Expected tensor for " , t1, " to have same size as tensor for " , t2,
82
+ " ; but " , t1->sizes (), " does not equal " , t2->sizes (),
83
+ " (while checking arguments for " , c, " )" );
99
84
}
100
85
101
86
void checkAllSameSize (CheckedFrom c, ArrayRef<TensorArg> tensors) {
102
87
checkAllSame (c, tensors, checkSameSize);
103
88
}
104
89
105
90
void checkNumel (CheckedFrom c, const TensorGeometryArg& t, int64_t numel) {
106
- if (t->numel () != numel) {
107
- std::ostringstream oss;
108
- oss << " Expected tensor for " << t << " to have "
109
- << numel << " elements; but it actually has " << t->numel () << " elements"
110
- << " (while checking arguments for " << c << " )" ;
111
- throw std::runtime_error (oss.str ());
112
- }
91
+ AT_CHECK (
92
+ t->numel () == numel,
93
+ " Expected tensor for " , t, " to have " , numel,
94
+ " elements; but it actually has " , t->numel (), " elements" ,
95
+ " (while checking arguments for " , c, " )" );
113
96
}
114
97
115
98
void checkSameNumel (CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
116
- if (t1->numel () != t2->numel ()) {
117
- std::ostringstream oss;
118
- oss << " Expected tensor for " << t1 << " to have same number of elements as tensor for "
119
- << t2 << " ; but " << t1->numel () << " does not equal " << t2->numel ()
120
- << " (while checking arguments for " << c << " )" ;
121
- throw std::runtime_error (oss.str ());
122
- }
99
+ AT_CHECK (
100
+ t1->numel () == t2->numel (),
101
+ " Expected tensor for " , t1,
102
+ " to have same number of elements as tensor for " , t2, " ; but " ,
103
+ t1->numel (), " does not equal " , t2->numel (),
104
+ " (while checking arguments for " , c, " )" );
123
105
}
124
106
125
107
void checkAllSameNumel (CheckedFrom c, ArrayRef<TensorArg> tensors) {
@@ -136,42 +118,34 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
136
118
oss << " Tensor for " << t2 << " is on CPU, " ;
137
119
}
138
120
oss << " but expected " << ((!(t1->is_cuda () || t2->is_cuda ())) ? " them" : " it" )
139
- << " to be on GPU (while checking arguments for " << c << " )" ;
140
- throw std::runtime_error (oss.str ());
141
- }
142
- if (t1->get_device () != t2->get_device ()) {
143
- std::ostringstream oss;
144
- oss << " Expected tensor for " << t1 << " to have the same device as "
145
- << " tensor for " << t2 << " ; but device " << t1->get_device () << " "
146
- << " does not equal " << t2->get_device ()
147
- << " (while checking arguments for " << c << " )" ;
148
- throw std::runtime_error (oss.str ());
121
+ << " to be on GPU (while checking arguments for " << c << " )" ;
122
+ AT_ERROR (oss.str ());
149
123
}
124
+ AT_CHECK (
125
+ t1->get_device () == t2->get_device (),
126
+ " Expected tensor for " , t1, " to have the same device as tensor for " , t2,
127
+ " ; but device " , t1->get_device (), " does not equal " , t2->get_device (),
128
+ " (while checking arguments for " , c, " )" );
150
129
}
151
130
152
131
void checkAllSameGPU (CheckedFrom c, ArrayRef<TensorArg> tensors) {
153
132
checkAllSame (c, tensors, checkSameGPU);
154
133
}
155
134
156
135
void checkSameType (CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
157
- if (t1->type () != t2->type ()) {
158
- std::ostringstream oss;
159
- oss << " Expected tensor for " << t1 << " to have the same type as "
160
- << " tensor for " << t2 << " ; but type " << t1->toString () << " "
161
- << " does not equal " << t2->toString ()
162
- << " (while checking arguments for " << c << " )" ;
163
- throw std::runtime_error (oss.str ());
164
- }
136
+ AT_CHECK (
137
+ t1->type () == t2->type (),
138
+ " Expected tensor for " , t1, " to have the same type as tensor for " , t2,
139
+ " ; but type " , t1->toString (), " does not equal " , t2->toString (),
140
+ " (while checking arguments for " , c, " )" );
165
141
}
166
142
167
143
void checkScalarType (CheckedFrom c, const TensorArg& t, ScalarType ty) {
168
- if (t->type ().scalarType () != ty) {
169
- std::ostringstream oss;
170
- oss << " Expected tensor for " << t << " to have scalar type "
171
- << toString (ty) << " ; but got " << t->toString ()
172
- << " instead (while checking arguments for " << c << " )" ;
173
- throw std::runtime_error (oss.str ());
174
- }
144
+ AT_CHECK (
145
+ t->type ().scalarType () == ty,
146
+ " Expected tensor for " , t, " to have scalar type " , toString (ty),
147
+ " ; but got " , t->toString (), " instead (while checking arguments for " , c,
148
+ " )" );
175
149
}
176
150
177
151
void checkScalarTypes (CheckedFrom c, const TensorArg& t,
@@ -190,7 +164,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t,
190
164
}
191
165
oss << " ; but got " << t->toString ()
192
166
<< " instead (while checking arguments for " << c << " )" ;
193
- throw std::runtime_error (oss.str ());
167
+ AT_ERROR (oss.str ());
194
168
}
195
169
}
196
170
@@ -199,24 +173,18 @@ void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
199
173
}
200
174
201
175
void checkSameDim (CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2) {
202
- if (t1->dim () != t2->dim ()) {
203
- std::ostringstream oss;
204
- oss << " Expected tensor for " << t1 << " to have the same dimension as "
205
- << " tensor for " << t2 << " ; but " << t1->dim () << " "
206
- << " does not equal " << t2->dim ()
207
- << " (while checking arguments for " << c << " )" ;
208
- throw std::runtime_error (oss.str ());
209
- }
176
+ AT_CHECK (
177
+ t1->dim () == t2->dim (),
178
+ " Expected tensor for " , t1, " to have the same dimension as tensor for " ,
179
+ t2, " ; but " , t1->dim (), " does not equal " , t2->dim (),
180
+ " (while checking arguments for " , c, " )" );
210
181
}
211
182
212
183
void checkDefined (CheckedFrom c, const TensorArg& t) {
213
- if (!t->defined ()) {
214
- std::ostringstream oss;
215
- oss << " Expected tensor for " << t << " to be non-null, "
216
- << " but it was undefined "
217
- << " (while checking arguments for " << c << " )" ;
218
- throw std::runtime_error (oss.str ());
219
- }
184
+ AT_CHECK (
185
+ t->defined (),
186
+ " Expected tensor for " , t, " to be non-null, but it was undefined " ,
187
+ " (while checking arguments for " , c, " )" );
220
188
}
221
189
222
190
void checkAllDefined (CheckedFrom c, ArrayRef<TensorArg> ts) {
@@ -227,13 +195,11 @@ void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {
227
195
}
228
196
229
197
void checkBackend (CheckedFrom c, const Tensor& t, Backend backend) {
230
- if (t.type ().backend () != backend) {
231
- std::ostringstream oss;
232
- oss << " Expected tensor to have " << toString (backend) << " Backend, but got tensor with "
233
- << toString (t.type ().backend ()) << " Backend "
234
- << " (while checking arguments for " << c << " )" ;
235
- throw std::runtime_error (oss.str ());
236
- }
198
+ AT_CHECK (
199
+ t.type ().backend () == backend,
200
+ " Expected tensor to have " , toString (backend),
201
+ " Backend, but got tensor with " , toString (t.type ().backend ()), " Backend " ,
202
+ " (while checking arguments for " , c, " )" );
237
203
}
238
204
239
205
void checkBackend (CheckedFrom c, ArrayRef<Tensor> tensors, at::Backend backend) {
0 commit comments