Skip to content

Commit 22a194d

Browse files
authored
Merge pull request #25 from iotamudelta/master
Merge from upstream
2 parents 6af217c + 907c9f1 commit 22a194d

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

torch/csrc/Module.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,26 @@ PyObject *THPModule_hasDistributed(PyObject *_unused)
282282
#endif
283283
}
284284

285+
void DLPack_Capsule_Destructor(PyObject* data) {
286+
HANDLE_TH_ERRORS
287+
DLManagedTensor * dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(data, "dltensor");
288+
if (dlMTensor) {
289+
// the dlMTensor has not been consumed, call deleter ourselves
290+
dlMTensor->deleter(const_cast<DLManagedTensor*>(dlMTensor));
291+
} else {
292+
// the dlMTensor has been consumed
293+
// PyCapsule_GetPointer has set an error indicator
294+
PyErr_Clear();
295+
}
296+
END_HANDLE_TH_ERRORS_RET()
297+
}
298+
285299
PyObject *THPModule_toDLPack(PyObject *_unused, PyObject *data)
286300
{
287301
HANDLE_TH_ERRORS
288302
THPUtils_assert(THPVariable_Check(data), "data must be a Tensor");
289303
DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_UnpackData(data));
290-
return PyCapsule_New(dlMTensor, "dltensor", NULL);
304+
return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
291305
END_HANDLE_TH_ERRORS
292306
}
293307

torch/nn/utils/clip_grad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2):
2929
total_norm = 0
3030
for p in parameters:
3131
param_norm = p.grad.data.norm(norm_type)
32-
total_norm += param_norm ** norm_type
32+
total_norm += param_norm.item() ** norm_type
3333
total_norm = total_norm ** (1. / norm_type)
3434
clip_coef = max_norm / (total_norm + 1e-6)
3535
if clip_coef < 1:
3636
for p in parameters:
37-
p.grad.data.mul_(clip_coef.item())
37+
p.grad.data.mul_(clip_coef)
3838
return total_norm
3939

4040

0 commit comments

Comments
 (0)