Skip to content

Commit da36d27

Browse files
nsajkoKristofferC
authored andcommitted
prevent unnecessary repeated squaring calculation (#58720)
(cherry picked from commit f61c640)
1 parent 0b32bbd commit da36d27

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

base/intfuncs.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,6 @@ function invmod(n::T) where {T<:BitInteger}
335335
end
336336

337337
# ^ for any x supporting *
338-
function to_power_type(x::Number)
339-
T = promote_type(typeof(x), typeof(x*x))
340-
convert(T, x)
341-
end
342-
to_power_type(x) = oftype(x*x, x)
343338
@noinline throw_domerr_powbysq(::Any, p) = throw(DomainError(p, LazyString(
344339
"Cannot raise an integer x to a negative power ", p, ".",
345340
"\nConvert input to float.")))
@@ -355,12 +350,23 @@ to_power_type(x) = oftype(x*x, x)
355350
"or write float(x)^", p, " or Rational.(x)^", p, ".")))
356351
# The * keyword supports `*=checked_mul` for `checked_pow`
357352
@assume_effects :terminates_locally function power_by_squaring(x_, p::Integer; mul=*)
358-
x = to_power_type(x_)
353+
x_squared_ = x_ * x_
354+
x_squared_type = typeof(x_squared_)
355+
T = if x_ isa Number
356+
promote_type(typeof(x_), x_squared_type)
357+
else
358+
x_squared_type
359+
end
360+
x = convert(T, x_)
361+
square_is_useful = mul === *
359362
if p == 1
360363
return copy(x)
361364
elseif p == 0
362365
return one(x)
363366
elseif p == 2
367+
if square_is_useful # avoid performing the same multiplication a second time when possible
368+
return convert(T, x_squared_)
369+
end
364370
return mul(x, x)
365371
elseif p < 0
366372
isone(x) && return copy(x)
@@ -369,6 +375,11 @@ to_power_type(x) = oftype(x*x, x)
369375
end
370376
t = trailing_zeros(p) + 1
371377
p >>= t
378+
if square_is_useful # avoid performing the same multiplication a second time when possible
379+
if (t -= 1) > 0
380+
x = convert(T, x_squared_)
381+
end
382+
end
372383
while (t -= 1) > 0
373384
x = mul(x, x)
374385
end

0 commit comments

Comments
 (0)