Skip to content

Commit 503e6cc

Browse files
erifan01cherrymui
erifan01
authored andcommitted
math/big: fix the bug in assembly implementation of shlVU on arm64
For the case where the addresses of parameter z and x of the function shlVU overlap and the address of z is greater than x, x (input value) can be polluted during the calculation when the high words of x are overlapped with the low words of z (output value). Fixes #31084 Change-Id: I9bb0266a1d7856b8faa9a9b1975d6f57dece0479 Reviewed-on: https://go-review.googlesource.com/c/go/+/169780 Run-TryBot: Cherry Zhang <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Cherry Zhang <[email protected]>
1 parent dc0388c commit 503e6cc

File tree

2 files changed

+128
-47
lines changed

2 files changed

+128
-47
lines changed

src/math/big/arith_arm64.s

+59-47
Original file line numberDiff line numberDiff line change
@@ -194,87 +194,97 @@ len0:
194194
MOVD R2, c+56(FP)
195195
RET
196196

197-
198197
// func shlVU(z, x []Word, s uint) (c Word)
198+
// This implementation handles the shift operation from the high word to the low word,
199+
// which may be an error for the case where the low word of x overlaps with the high
200+
// word of z. When calling this function directly, you need to pay attention to this
201+
// situation.
199202
TEXT ·shlVU(SB),NOSPLIT,$0
200-
// Disable assembly for now - it is subtly incorrect.
201-
// See #31084 for a test that fails using this code.
202-
B ·shlVU_g(SB)
203-
204-
MOVD z+0(FP), R0
205-
MOVD z_len+8(FP), R1
203+
LDP z+0(FP), (R0, R1) // R0 = z.ptr, R1 = len(z)
206204
MOVD x+24(FP), R2
207205
MOVD s+48(FP), R3
208-
MOVD $0, R8 // in order not to affect the first element, R8 is initialized to zero
209-
MOVD $64, R4
210-
SUB R3, R4
206+
ADD R1<<3, R0 // R0 = &z[n]
207+
ADD R1<<3, R2 // R2 = &x[n]
211208
CBZ R1, len0
212209
CBZ R3, copy // if the number of shift is 0, just copy x to z
213-
214-
TBZ $0, R1, two
215-
MOVD.P 8(R2), R6
216-
LSR R4, R6, R8
217-
LSL R3, R6
218-
MOVD.P R6, 8(R0)
210+
MOVD $64, R4
211+
SUB R3, R4
212+
// handling the most significant element x[n-1]
213+
MOVD.W -8(R2), R6
214+
LSR R4, R6, R5 // return value
215+
LSL R3, R6, R8 // x[i] << s
219216
SUB $1, R1
217+
one: TBZ $0, R1, two
218+
MOVD.W -8(R2), R6
219+
LSR R4, R6, R7
220+
ORR R8, R7
221+
LSL R3, R6, R8
222+
SUB $1, R1
223+
MOVD.W R7, -8(R0)
220224
two:
221225
TBZ $1, R1, loop
222-
LDP.P 16(R2), (R6, R7)
223-
LSR R4, R6, R9
224-
LSL R3, R6
225-
ORR R8, R6
226-
LSR R4, R7, R8
226+
LDP.W -16(R2), (R6, R7)
227+
LSR R4, R7, R10
228+
ORR R8, R10
227229
LSL R3, R7
228-
ORR R9, R7
229-
STP.P (R6, R7), 16(R0)
230+
LSR R4, R6, R9
231+
ORR R7, R9
232+
LSL R3, R6, R8
230233
SUB $2, R1
234+
STP.W (R9, R10), -16(R0)
231235
loop:
232236
CBZ R1, done
233-
LDP.P 32(R2), (R10, R11)
234-
LDP -16(R2), (R12, R13)
235-
LSR R4, R10, R20
236-
LSL R3, R10
237-
ORR R8, R10 // z[i] = (x[i] << s) | (x[i-1] >> (64 - s))
238-
LSR R4, R11, R21
239-
LSL R3, R11
240-
ORR R20, R11
237+
LDP.W -32(R2), (R10, R11)
238+
LDP 16(R2), (R12, R13)
239+
LSR R4, R13, R23
240+
ORR R8, R23 // z[i] = (x[i] << s) | (x[i-1] >> (64 - s))
241+
LSL R3, R13
241242
LSR R4, R12, R22
243+
ORR R13, R22
242244
LSL R3, R12
243-
ORR R21, R12
244-
LSR R4, R13, R8
245-
LSL R3, R13
246-
ORR R22, R13
247-
STP.P (R10, R11), 32(R0)
248-
STP (R12, R13), -16(R0)
245+
LSR R4, R11, R21
246+
ORR R12, R21
247+
LSL R3, R11
248+
LSR R4, R10, R20
249+
ORR R11, R20
250+
LSL R3, R10, R8
251+
STP.W (R20, R21), -32(R0)
252+
STP (R22, R23), 16(R0)
249253
SUB $4, R1
250254
B loop
251255
done:
252-
MOVD R8, c+56(FP) // the part moved out from the last element
256+
MOVD.W R8, -8(R0) // the first element x[0]
257+
MOVD R5, c+56(FP) // the part moved out from x[n-1]
253258
RET
254259
copy:
260+
CMP R0, R2
261+
BEQ len0
255262
TBZ $0, R1, ctwo
256-
MOVD.P 8(R2), R3
257-
MOVD.P R3, 8(R0)
263+
MOVD.W -8(R2), R4
264+
MOVD.W R4, -8(R0)
258265
SUB $1, R1
259266
ctwo:
260267
TBZ $1, R1, cloop
261-
LDP.P 16(R2), (R4, R5)
262-
STP.P (R4, R5), 16(R0)
268+
LDP.W -16(R2), (R4, R5)
269+
STP.W (R4, R5), -16(R0)
263270
SUB $2, R1
264271
cloop:
265272
CBZ R1, len0
266-
LDP.P 32(R2), (R4, R5)
267-
LDP -16(R2), (R6, R7)
268-
STP.P (R4, R5), 32(R0)
269-
STP (R6, R7), -16(R0)
273+
LDP.W -32(R2), (R4, R5)
274+
LDP 16(R2), (R6, R7)
275+
STP.W (R4, R5), -32(R0)
276+
STP (R6, R7), 16(R0)
270277
SUB $4, R1
271278
B cloop
272279
len0:
273280
MOVD $0, c+56(FP)
274281
RET
275282

276-
277283
// func shrVU(z, x []Word, s uint) (c Word)
284+
// This implementation handles the shift operation from the low word to the high word,
285+
// which may be an error for the case where the high word of x overlaps with the low
286+
// word of z. When calling this function directly, you need to pay attention to this
287+
// situation.
278288
TEXT ·shrVU(SB),NOSPLIT,$0
279289
MOVD z+0(FP), R0
280290
MOVD z_len+8(FP), R1
@@ -334,6 +344,8 @@ done:
334344
MOVD R8, (R0) // deal with the last element
335345
RET
336346
copy:
347+
CMP R0, R2
348+
BEQ len0
337349
TBZ $0, R1, ctwo
338350
MOVD.P 8(R2), R3
339351
MOVD.P R3, 8(R0)

src/math/big/arith_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,75 @@ func TestFunVW(t *testing.T) {
213213
}
214214
}
215215

216+
type argVU struct {
217+
d []Word // d is a Word slice, the input parameters x and z come from this array.
218+
l uint // l is the length of the input parameters x and z.
219+
xp uint // xp is the starting position of the input parameter x, x := d[xp:xp+l].
220+
zp uint // zp is the starting position of the input parameter z, z := d[zp:zp+l].
221+
s uint // s is the shift number.
222+
r []Word // r is the expected output result z.
223+
c Word // c is the expected return value.
224+
m string // message.
225+
}
226+
227+
var argshlVU = []argVU{
228+
// test cases for shlVU
229+
{[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0}, 7, 0, 0, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "complete overlap of shlVU"},
230+
{[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0, 0, 0, 0}, 7, 0, 3, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "partial overlap by half of shlVU"},
231+
{[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0, 0, 0, 0, 0, 0, 0}, 7, 0, 6, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "partial overlap by 1 Word of shlVU"},
232+
{[]Word{1, _M, _M, _M, _M, _M, 3 << (_W - 2), 0, 0, 0, 0, 0, 0, 0, 0}, 7, 0, 7, 1, []Word{2, _M - 1, _M, _M, _M, _M, 1<<(_W-1) + 1}, 1, "no overlap of shlVU"},
233+
}
234+
235+
var argshrVU = []argVU{
236+
// test cases for shrVU
237+
{[]Word{0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 1, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "complete overlap of shrVU"},
238+
{[]Word{0, 0, 0, 0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 4, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "partial overlap by half of shrVU"},
239+
{[]Word{0, 0, 0, 0, 0, 0, 0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 7, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "partial overlap by 1 Word of shrVU"},
240+
{[]Word{0, 0, 0, 0, 0, 0, 0, 0, 3, _M, _M, _M, _M, _M, 1 << (_W - 1)}, 7, 8, 1, 1, []Word{1<<(_W-1) + 1, _M, _M, _M, _M, _M >> 1, 1 << (_W - 2)}, 1 << (_W - 1), "no overlap of shrVU"},
241+
}
242+
243+
func testShiftFunc(t *testing.T, f func(z, x []Word, s uint) Word, a argVU) {
244+
// save a.d for error message, or it will be overwritten.
245+
b := make([]Word, len(a.d))
246+
copy(b, a.d)
247+
z := a.d[a.zp : a.zp+a.l]
248+
x := a.d[a.xp : a.xp+a.l]
249+
c := f(z, x, a.s)
250+
for i, zi := range z {
251+
if zi != a.r[i] {
252+
t.Errorf("d := %v, %s(d[%d:%d], d[%d:%d], %d)\n\tgot z[%d] = %#x; want %#x", b, a.m, a.zp, a.zp+a.l, a.xp, a.xp+a.l, a.s, i, zi, a.r[i])
253+
break
254+
}
255+
}
256+
if c != a.c {
257+
t.Errorf("d := %v, %s(d[%d:%d], d[%d:%d], %d)\n\tgot c = %#x; want %#x", b, a.m, a.zp, a.zp+a.l, a.xp, a.xp+a.l, a.s, c, a.c)
258+
}
259+
}
260+
261+
func TestShiftOverlap(t *testing.T) {
262+
for _, a := range argshlVU {
263+
arg := a
264+
testShiftFunc(t, shlVU, arg)
265+
}
266+
267+
for _, a := range argshrVU {
268+
arg := a
269+
testShiftFunc(t, shrVU, arg)
270+
}
271+
}
272+
273+
func TestIssue31084(t *testing.T) {
274+
// compute 10^n via 5^n << n.
275+
const n = 165
276+
p := nat(nil).expNN(nat{5}, nat{n}, nil)
277+
p = p.shl(p, uint(n))
278+
got := string(p.utoa(10))
279+
want := "1" + strings.Repeat("0", n)
280+
if got != want {
281+
t.Errorf("shl(%v, %v)\n\tgot %s; want %s\n", p, uint(n), got, want)
282+
}
283+
}
284+
216285
func BenchmarkAddVW(b *testing.B) {
217286
for _, n := range benchSizes {
218287
if isRaceBuilder && n > 1e3 {

0 commit comments

Comments
 (0)