Skip to content

Commit e30ee13

Browse files
gchananfacebook-github-bot
authored andcommitted
Small fixups for enabling zero size dims. (#9724)
Summary: 1) Properly test cpu for alpha/beta addmm cases. 2) Unsqueeze on empty no longer throws an exception. Pull Request resolved: pytorch/pytorch#9724 Reviewed By: ezyang Differential Revision: D8958513 Pulled By: gchanan fbshipit-source-id: 6ce2ec4a47201f9b225b8c52354144ace43e9e09
1 parent 1d7ff8a commit e30ee13

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

aten/src/TH/generic/THTensorMath.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,16 @@ void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
859859
THTensor_(free)(cmat);
860860
}
861861

862+
// In gemv (x,0).mv(0) does not
863+
// handle beta, whereas gemm does for case where (x,0).mm(0,y).
864+
if (vec->size(0) == 0 && mat->size(0) != 0) {
865+
if (beta == 0) {
866+
THTensor_(zero)(r_);
867+
} else if (beta != 1) {
868+
THTensor_(mul)(r_, r_, beta);
869+
}
870+
}
871+
862872
#undef LDA_COND
863873
}
864874

aten/src/THC/generic/THCTensorMathBlas.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
116116
THCTensor_(free)(state, cmat);
117117
}
118118

119-
// cublasSgemv, cublasDgemv have a bug where (x,0).mv(0) does not
119+
// In cublasSgemv, cublasDgemv (x,0).mv(0) does not
120120
// handle beta, whereas cublasSgemm, cublasDgemm do for case where (x,0).mm(0,y).
121121
if (vec->size(0) == 0 && mat->size(0) != 0) {
122122
if(THCNumerics<real>::eq(beta, ScalarConvert<int, real>::to(0))) {

0 commit comments

Comments
 (0)