Skip to content

Commit ca0aeb2

Browse files
committed
Check size on transformer cache
1 parent e88c1aa commit ca0aeb2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/models/transformer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class Transformer : public EncoderOrDecoderBase {
281281
// memoization propagation (short-term)
282282
if (cache // if caching
283283
&& cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
284-
&& cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
284+
&& cache_[prefix + "_keys"]->shape() == keys->shape()) { // and the underlying shape did not change
285285
kh = cache_[prefix + "_keys"]; // then return cached tensor
286286
}
287287
else {
@@ -296,7 +296,7 @@ class Transformer : public EncoderOrDecoderBase {
296296
Expr vh;
297297
if (cache
298298
&& cache_.count(prefix + "_values") > 0
299-
&& cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
299+
&& cache_[prefix + "_values"]->shape() == values->shape()) {
300300
vh = cache_[prefix + "_values"];
301301
} else {
302302
auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorotUniform());

0 commit comments

Comments (0)