
Commit bb71244

wanchaol authored and Rob Kunkle committed
Fix the clamp special case and gradient problem on None, add None to JIT (pytorch#9596)
Summary: Supersedes pytorch#8925. This PR fixes pytorch#8502: it fixes the gradient problem for clamp when None is passed to the function, and adds support for NoneLiteral and NoneType in script so the clamp tests can be enabled. We can now handle corner cases like:

```python
@torch.jit.script
def func():
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.clamp(x, None, 0)  # max = 0
    y = torch.clamp(x, min=None, max=0)
```

In both the JIT and ATen, we use Scalar(NAN) as a sentinel value when None is passed to clamp; this is the current way we support the None type in the JIT and solve the gradient problem when the user explicitly passes None to clamp. On the JIT side, we create a tensor(NAN) or an undefined tensor when we encounter None while matching the function schema, and the interpreter later translates it to Scalar(NAN) if needed. Ideally we would no longer need clamp_min and clamp_max in ATen native/autograd and could support only clamp after this change, but since a bunch of other operators (e.g. Activation.cpp, Loss.cpp) use clamp_min in several places, those functions remain available; all Python invocations, however, will call only clamp instead of clamp_min/clamp_max (with clamp calling the underlying th_max/th_min). zdevito jamesr66a

Pull Request resolved: pytorch#9596
Reviewed By: zdevito
Differential Revision: D8940839
Pulled By: wanchaol
fbshipit-source-id: c543a867b82e0ab8c99384773b173fdde2605d28
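As a quick illustration of the user-visible behavior described above (a minimal sketch, not part of this diff; the tensor contents are arbitrary):

```python
import torch

x = torch.randn(3, 3, requires_grad=True)

# With this change, one bound may be None: only the other side clamps.
y = torch.clamp(x, min=None, max=0.0)
y.sum().backward()

# Subgradient convention from the summary: the gradient is 1 wherever the
# element was not clamped away (here, wherever x <= 0) and 0 elsewhere.
print(x.grad)
```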
1 parent 2e48656 commit bb71244

23 files changed: +172 -164 lines changed

aten/src/ATen/native/UnaryOps.cpp

Lines changed: 17 additions & 2 deletions
@@ -46,7 +46,15 @@ Tensor clamp_min(const Tensor& self, Scalar min) {
 }
 
 Tensor& _clamp__cpu(Tensor& self, Scalar min, Scalar max) {
-  return _th_clamp_(self, min, max);
+  if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
+    return _th_clamp_(self, min, max);
+  } else if (std::isnan(min.toDouble())) {
+    return _th_clamp_max_(self, max);
+  } else if (std::isnan(max.toDouble())) {
+    return _th_clamp_min_(self, min);
+  } else {
+    return self;
+  }
 }
 
 Tensor& _clamp_out_cpu(
@@ -56,7 +64,14 @@ Tensor& _clamp_out_cpu(
     Scalar max) {
   result.resize_(self.sizes());
   result.copy_(self);
-  return _th_clamp_(result, min, max);
+  if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
+    _th_clamp_(result, min, max);
+  } else if (std::isnan(min.toDouble())) {
+    _th_clamp_max_(result, max);
+  } else if (std::isnan(max.toDouble())) {
+    _th_clamp_min_(result, min);
+  }
+  return result;
 }
 
 Tensor& _clamp_max__cpu(Tensor& self, Scalar max) {
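The branching above treats NAN as the sentinel that ATen receives when Python passes None for min or max (see the python_default_init entries in native_functions.yaml further down). Below is a rough Python analogue of the same dispatch, using a hypothetical helper name and with the branches reordered so the "no bound at all" case is explicit; it is a sketch for illustration, not code from this PR:

```python
import math
import torch

def clamp_like_dispatch(x, min=float("nan"), max=float("nan")):
    # NaN stands in for an omitted bound, mirroring python_default_init: min/max: NAN.
    min_given = not math.isnan(min)
    max_given = not math.isnan(max)
    if min_given and max_given:
        return x.clamp(min, max)    # both bounds: full clamp
    if max_given:
        return x.clamp_max(max)     # min omitted: clamp from above only
    if min_given:
        return x.clamp_min(min)     # max omitted: clamp from below only
    return x                        # neither bound given: no-op
```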

aten/src/ATen/native/cuda/CUDAUnaryOps.cpp

Lines changed: 17 additions & 2 deletions
@@ -3,7 +3,15 @@
 namespace at { namespace native {
 
 Tensor& _clamp__cuda(Tensor& self, Scalar min, Scalar max) {
-  return _th_clamp_(self, min, max);
+  if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
+    return _th_clamp_(self, min, max);
+  } else if (std::isnan(min.toDouble())) {
+    return _th_clamp_max_(self, max);
+  } else if (std::isnan(max.toDouble())) {
+    return _th_clamp_min_(self, min);
+  } else {
+    return self;
+  }
 }
 
 Tensor& _clamp_out_cuda(
@@ -13,7 +21,14 @@ Tensor& _clamp_out_cuda(
     Scalar max) {
   result.resize_(self.sizes());
   result.copy_(self);
-  return _th_clamp_(result, min, max);
+  if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
+    _th_clamp_(result, min, max);
+  } else if (std::isnan(min.toDouble())) {
+    _th_clamp_max_(result, max);
+  } else if (std::isnan(max.toDouble())) {
+    _th_clamp_min_(result, min);
+  }
+  return result;
 }
 
 Tensor& _clamp_max__cuda(Tensor& self, Scalar max) {

aten/src/ATen/native/native_functions.yaml

Lines changed: 9 additions & 0 deletions
@@ -266,17 +266,26 @@
 - func: chunk(Tensor self, int64_t chunks, int64_t dim=0) -> TensorList
 
 - func: clamp(Tensor self, Scalar min, Scalar max) -> Tensor
+  python_default_init:
+    min: NAN
+    max: NAN
 
 - func: clamp_(Tensor self, Scalar min, Scalar max) -> Tensor
   dispatch:
     CPU: _clamp__cpu
     CUDA: _clamp__cuda
+  python_default_init:
+    min: NAN
+    max: NAN
 
 - func: clamp_out(Tensor result, Tensor self, Scalar min, Scalar max) -> Tensor
   variants: function
   dispatch:
     CPU: _clamp_out_cpu
     CUDA: _clamp_out_cuda
+  python_default_init:
+    min: NAN
+    max: NAN
 
 - func: clamp_max(Tensor self, Scalar max) -> Tensor

test/expect/TestScript.test_python_frontend.expect

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,10 @@
   (param (ident y) (tensor_type))
   (param (ident z) (tensor_type)))
 (list
+  (assign
+    (list (variable (ident q)))
+    (=)
+    (None))
   (assign
     (list (variable (ident q)))
     (=)

test/test_jit.py

Lines changed: 23 additions & 4 deletions
@@ -1629,6 +1629,7 @@ def to_int(x):
 
     def test_python_frontend(self):
         def fn(x, y, z):
+            q = None
             q = x + y - z.sigmoid()
             print(q)
             w = -z
@@ -1862,6 +1863,28 @@ def test_script_for_in_range_if_ast(x):
 
         self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
 
+    def test_script_None(self):
+        def func(x):
+            output = None
+            output = x
+            return output
+
+        self.checkScript(func, [torch.arange(0, 2)], optimize=True)
+
+    def test_script_clamp_none(self):
+        # TODO: could not enable default/optional argument for None in JIT
+        # result from Aten native python_default_init for clamp, it is used
+        # in Aten but not in JIT, need to fix type/default arg system in ATen
+        def test_script_clamp_max_none(x):
+            return torch.clamp(x, min=None, max=2)
+
+        def test_script_clamp_min_none(x):
+            return torch.clamp(x, min=2, max=None)
+
+        input = [torch.arange(0, 3)]
+        self.checkScript(test_script_clamp_max_none, input, optimize=True)
+        self.checkScript(test_script_clamp_min_none, input, optimize=True)
+
     def test_script_bool_constant(self):
         script = '''
         def test_script_bool_constant():
@@ -4845,10 +4868,6 @@ def forward(self, x, y):
 
 # known to be failing in script
 EXCLUDE_SCRIPT = {
-    'test_clamp_max',
-    'test_clamp_max_scalar',
-    'test_clamp_min',
-    'test_clamp_min_scalar',
     # TODO: Fix var/std
     # there are two schemas for var (and std):
     # (1) var(Tensor, int, *, bool, bool, Tensor)

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 4 deletions
@@ -168,11 +168,10 @@
 - name: ceil(Tensor self)
   self: zeros_like(grad)
 
-# For clamp, clamp_min, and clamp_max, gradient is not defined at the
-# boundaries. But empirically it's helpful to be able to get gradient on min and
-# max, so we return the subgradient 1 for these cases.
+# For clamp, gradient is not defined at the boundaries. But empirically it's helpful
+# to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
 - name: clamp(Tensor self, Scalar min, Scalar max)
-  self: grad * ((self >= min) * (self <= max)).type_as(grad)
+  self: clamp_backward(grad, self, min, max)
 
 - name: clamp_min(Tensor self, Scalar min)
   self: grad * (self >= min).type_as(grad)

tools/autograd/gen_python_functions.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 # These functions require manual Python bindings or are not exposed to Python
 SKIP_PYTHON_BINDINGS = [
-    'alias', 'contiguous', 'clamp.*', 'is_cuda', 'is_sparse', 'size', 'stride',
+    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
     '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
     '.*_forward_out', 'sparse_raw_resize_', '_unsafe_view', 'tensor',
     'sparse_coo_tensor', 'th_sparse_coo_tensor', 'native_sparse_coo_tensor',

tools/autograd/templates/Functions.cpp

Lines changed: 11 additions & 0 deletions
@@ -431,6 +431,17 @@ std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<
   return grad_inputs;
 }
 
+Tensor clamp_backward(const Tensor & grad, const Tensor &self, const Scalar & min, const Scalar & max) {
+  // clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases.
+  if (std::isnan(min.toFloat())) {
+    return grad * (self <= max).type_as(grad);
+  } else if (std::isnan(max.toFloat())) {
+    return grad * (self >= min).type_as(grad);
+  } else {
+    return grad * ((self >= min) * (self <= max)).type_as(grad);
+  }
+}
+
 Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, IntList sizes, IntList strides, const Scalar & alpha) {
   // if input was column-major, return grad as column-order for efficiency
   if (strides[0] == 1 && strides[1] == sizes[0]) {
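To make the subgradient behavior of clamp_backward concrete, here is a rough Python rendering of the same three branches (a sketch with a hypothetical function name, not the code autograd actually generates):

```python
import math
import torch

def clamp_backward_sketch(grad, self, min, max):
    # As in the C++ above, NaN in min or max marks an omitted bound.
    if math.isnan(min):      # only max was given
        return grad * (self <= max).type_as(grad)
    if math.isnan(max):      # only min was given
        return grad * (self >= min).type_as(grad)
    # Both bounds given: subgradient 1 inside [min, max], boundaries included.
    return grad * ((self >= min) & (self <= max)).type_as(grad)
```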

tools/autograd/templates/python_torch_functions.cpp

Lines changed: 0 additions & 36 deletions
@@ -191,41 +191,6 @@ static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject
   END_HANDLE_TH_ERRORS
 }
 
-// The Python clamp() syntax has to be mapped to one of three C++ functions
-static PyObject * THPVariable_clamp(PyObject* module, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "clamp(Tensor input, Scalar min=None, Scalar max=None, *, Tensor out=None)",
-  });
-
-  ParsedArgs<4> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (!r.isNone(1) && !r.isNone(2)) {
-    if (!r.isNone(3)) {
-      return wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2), r.tensor(3)));
-    } else {
-      return wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2)));
-    }
-  } else if (!r.isNone(1)) {
-    if (!r.isNone(3)) {
-      return wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1), r.tensor(3)));
-    } else {
-      return wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1)));
-    }
-  } else if (!r.isNone(2)) {
-    if (!r.isNone(3)) {
-      return wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2), r.tensor(3)));
-    } else {
-      return wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2)));
-    }
-  } else {
-    throw std::runtime_error("At least one of 'min' or 'max' must not be None");
-  }
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
 static PyObject * THPVariable_from_numpy(PyObject* module, PyObject* arg)
 {
   HANDLE_TH_ERRORS
@@ -271,7 +236,6 @@ static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* k
 static PyMethodDef torch_functions[] = {
   {"arange", (PyCFunction)THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"as_tensor", (PyCFunction)THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
-  {"clamp", (PyCFunction)THPVariable_clamp, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
   {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
   {"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},

tools/autograd/templates/python_torch_functions_dispatch.h

Lines changed: 0 additions & 26 deletions
@@ -34,32 +34,6 @@ static void maybe_initialize_cuda(const at::Type &type) {
   }
 }
 
-// manual dispatch code for clamp
-inline Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
-  AutoNoGIL no_gil;
-  return self.clamp(min, max);
-}
-inline Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
-  AutoNoGIL no_gil;
-  return self.clamp_min(min);
-}
-inline Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
-  AutoNoGIL no_gil;
-  return self.clamp_max(max);
-}
-inline Tensor & dispatch_clamp(const Tensor & self, Scalar min, Scalar max, Tensor result) {
-  AutoNoGIL no_gil;
-  return at::clamp_out(result, self, min, max);
-}
-inline Tensor & dispatch_clamp_min(const Tensor & self, Scalar min, Tensor result) {
-  AutoNoGIL no_gil;
-  return at::clamp_min_out(result, self, min);
-}
-inline Tensor & dispatch_clamp_max(const Tensor & self, Scalar max, Tensor result) {
-  AutoNoGIL no_gil;
-  return at::clamp_max_out(result, self, max);
-}
-
 ${py_method_dispatch}
 
 }} // namespace torch::autograd

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 0 additions & 76 deletions
@@ -55,80 +55,6 @@ static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
   END_HANDLE_TH_ERRORS
 }
 
-static Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
-  AutoNoGIL no_gil;
-  DeviceGuard device_guard(self);
-  return self.clamp(min, max);
-}
-static Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
-  AutoNoGIL no_gil;
-  DeviceGuard device_guard(self);
-  return self.clamp_min(min);
-}
-static Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
-  AutoNoGIL no_gil;
-  DeviceGuard device_guard(self);
-  return self.clamp_max(max);
-}
-
-static PyObject * THPVariable_clamp(PyObject* self, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "clamp(Scalar min=None, Scalar max=None)",
-  }, /*traceable=*/true);
-  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (!r.isNone(0) && !r.isNone(1)) {
-    return THPVariable_Wrap(dispatch_clamp(self_, r.scalar(0), r.scalar(1)));
-  } else if (!r.isNone(0)) {
-    return THPVariable_Wrap(dispatch_clamp_min(self_, r.scalar(0)));
-  } else if (!r.isNone(1)) {
-    return THPVariable_Wrap(dispatch_clamp_max(self_, r.scalar(1)));
-  } else {
-    throw std::runtime_error("At least one of 'min' or 'max' must not be None");
-  }
-  END_HANDLE_TH_ERRORS
-}
-
-static Tensor & dispatch_clamp_(Tensor & self, Scalar min, Scalar max) {
-  AutoNoGIL no_gil;
-  DeviceGuard device_guard(self);
-  return self.clamp_(min, max);
-}
-static Tensor & dispatch_clamp_min_(Tensor & self, Scalar min) {
-  AutoNoGIL no_gil;
-  DeviceGuard device_guard(self);
-  return self.clamp_min_(min);
-}
-static Tensor & dispatch_clamp_max_(Tensor & self, Scalar max) {
-  AutoNoGIL no_gil;
-  DeviceGuard device_guard(self);
-  return self.clamp_max_(max);
-}
-
-static PyObject * THPVariable_clamp_(PyObject* self, PyObject* args, PyObject* kwargs)
-{
-  HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "clamp_(Scalar min=None, Scalar max=None)",
-  }, /*traceable=*/true);
-  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (!r.isNone(0) && !r.isNone(1)) {
-    return THPVariable_Wrap(dispatch_clamp_(self_, r.scalar(0), r.scalar(1)));
-  } else if (!r.isNone(0)) {
-    return THPVariable_Wrap(dispatch_clamp_min_(self_, r.scalar(0)));
-  } else if (!r.isNone(1)) {
-    return THPVariable_Wrap(dispatch_clamp_max_(self_, r.scalar(1)));
-  } else {
-    throw std::runtime_error("At least one of 'min' or 'max' must not be None");
-  }
-  END_HANDLE_TH_ERRORS
-}
-
 static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
@@ -661,8 +587,6 @@ PyMethodDef variable_methods[] = {
   {"apply_", (PyCFunction)THPVariable_apply_, METH_O, NULL},
   {"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL},
   {"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL},
-  {"clamp", (PyCFunction)THPVariable_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
-  {"clamp_", (PyCFunction)THPVariable_clamp_, METH_VARARGS | METH_KEYWORDS, NULL},
   {"contiguous", (PyCFunction)THPVariable_contiguous, METH_NOARGS, NULL},
   {"copy_", (PyCFunction)THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL},
   {"cpu", (PyCFunction)THPVariable_cpu, METH_NOARGS, NULL},

torch/csrc/jit/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ _(namespaces, scope) \
 _(namespaces, namespaces) \
 _(prim, Assign) \
 _(prim, Constant) \
+_(prim, None) \
 _(prim, Drop) \
 _(prim, Eval) \
 _(prim, Expand) /* onnx */ \

torch/csrc/jit/operator.cpp

Lines changed: 5 additions & 5 deletions
@@ -104,6 +104,9 @@ struct SchemaParser {
       case TK_FALSE:
        L.next();
        return false;
+      case TK_NONE:
+        L.next();
+        return IValue();
       case TK_IDENT: {
        auto tok = L.next();
        auto text = tok.text();
@@ -160,11 +163,8 @@ struct SchemaParser {
   }
 
   IValue parseTensorDefault(const SourceRange& range) {
-    if("None" == L.expect(TK_IDENT).text()) {
-      return at::Tensor();
-    } else {
-      throw ErrorReport(range) << "invalid tensor default value";
-    }
+    L.expect(TK_NONE);
+    return IValue();
   }
   void parseDefaultValue(Argument& arg) {
     auto range = L.cur().range;

0 commit comments
