Skip to content

Commit d9b79e5

Browse files
lijunchencherrymui
authored andcommitted
cmd/compile: fix wrong complement for arm64 floating-point comparisons
Consider the following example, func test(a, b float64, x uint64) uint64 { if a < b { x = 0 } return x } func main() { fmt.Println(test(1, math.NaN(), 123)) } The output is 0, but the expectation is 123. This is because the rewrite rule (CSEL [cc] (MOVDconst [0]) y flag) => (CSEL0 [arm64Negate(cc)] y flag) converts FCMP NaN, 1 CSEL MI, 0, 123, R0 // if 1 < NaN then R0 = 0 else R0 = 123 to FCMP NaN, 1 CSEL GE, 123, 0, R0 // if 1 >= NaN then R0 = 123 else R0 = 0 But both 1 < NaN and 1 >= NaN are false. So the output is 0, not 123. The root cause is arm64Negate not handle negation of floating comparison correctly. According to the ARM manual, the meaning of MI, GE, and PL are MI: Less than GE: Greater than or equal to PL: Greater than, equal to, or unordered Because NaN cannot be compared with other numbers, the result of such comparison is unordered. So when NaN is involved, unlike integer, the result of !(a < b) is not a >= b, it is a >= b || a is NaN || b is NaN. This is exactly what PL means. We add NotLessThanF to represent PL. Then the negation of LessThanF is NotLessThanF rather than GreaterEqualF. The same reason for the other floating comparison operations. Fixes #43619 Change-Id: Ia511b0027ad067436bace9fbfd261dbeaae01bcd Reviewed-on: https://go-review.googlesource.com/c/go/+/283572 Reviewed-by: Cherry Zhang <[email protected]> Run-TryBot: Cherry Zhang <[email protected]> TryBot-Result: Go Bot <[email protected]> Trust: Keith Randall <[email protected]>
1 parent c73232d commit d9b79e5

File tree

5 files changed

+215
-27
lines changed

5 files changed

+215
-27
lines changed

src/cmd/compile/internal/arm64/ssa.go

+15-5
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,11 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) {
10541054
ssa.OpARM64LessThanF,
10551055
ssa.OpARM64LessEqualF,
10561056
ssa.OpARM64GreaterThanF,
1057-
ssa.OpARM64GreaterEqualF:
1057+
ssa.OpARM64GreaterEqualF,
1058+
ssa.OpARM64NotLessThanF,
1059+
ssa.OpARM64NotLessEqualF,
1060+
ssa.OpARM64NotGreaterThanF,
1061+
ssa.OpARM64NotGreaterEqualF:
10581062
// generate boolean values using CSET
10591063
p := s.Prog(arm64.ACSET)
10601064
p.From.Type = obj.TYPE_REG // assembler encodes conditional bits in Reg
@@ -1098,10 +1102,16 @@ var condBits = map[ssa.Op]int16{
10981102
ssa.OpARM64GreaterThanU: arm64.COND_HI,
10991103
ssa.OpARM64GreaterEqual: arm64.COND_GE,
11001104
ssa.OpARM64GreaterEqualU: arm64.COND_HS,
1101-
ssa.OpARM64LessThanF: arm64.COND_MI,
1102-
ssa.OpARM64LessEqualF: arm64.COND_LS,
1103-
ssa.OpARM64GreaterThanF: arm64.COND_GT,
1104-
ssa.OpARM64GreaterEqualF: arm64.COND_GE,
1105+
ssa.OpARM64LessThanF: arm64.COND_MI, // Less than
1106+
ssa.OpARM64LessEqualF: arm64.COND_LS, // Less than or equal to
1107+
ssa.OpARM64GreaterThanF: arm64.COND_GT, // Greater than
1108+
ssa.OpARM64GreaterEqualF: arm64.COND_GE, // Greater than or equal to
1109+
1110+
// The following condition codes have unordered to handle comparisons related to NaN.
1111+
ssa.OpARM64NotLessThanF: arm64.COND_PL, // Greater than, equal to, or unordered
1112+
ssa.OpARM64NotLessEqualF: arm64.COND_HI, // Greater than or unordered
1113+
ssa.OpARM64NotGreaterThanF: arm64.COND_LE, // Less than, equal to or unordered
1114+
ssa.OpARM64NotGreaterEqualF: arm64.COND_LT, // Less than or unordered
11051115
}
11061116

11071117
var blockJump = map[ssa.BlockKind]struct {

src/cmd/compile/internal/ssa/gen/ARM64Ops.go

+18-14
Original file line numberDiff line numberDiff line change
@@ -478,20 +478,24 @@ func init() {
478478
// pseudo-ops
479479
{name: "LoweredNilCheck", argLength: 2, reg: regInfo{inputs: []regMask{gpg}}, nilCheck: true, faultOnNilArg0: true}, // panic if arg0 is nil. arg1=mem.
480480

481-
{name: "Equal", argLength: 1, reg: readflags}, // bool, true flags encode x==y false otherwise.
482-
{name: "NotEqual", argLength: 1, reg: readflags}, // bool, true flags encode x!=y false otherwise.
483-
{name: "LessThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x<y false otherwise.
484-
{name: "LessEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x<=y false otherwise.
485-
{name: "GreaterThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x>y false otherwise.
486-
{name: "GreaterEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y false otherwise.
487-
{name: "LessThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<y false otherwise.
488-
{name: "LessEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<=y false otherwise.
489-
{name: "GreaterThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>y false otherwise.
490-
{name: "GreaterEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>=y false otherwise.
491-
{name: "LessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y false otherwise.
492-
{name: "LessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y false otherwise.
493-
{name: "GreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y false otherwise.
494-
{name: "GreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y false otherwise.
481+
{name: "Equal", argLength: 1, reg: readflags}, // bool, true flags encode x==y false otherwise.
482+
{name: "NotEqual", argLength: 1, reg: readflags}, // bool, true flags encode x!=y false otherwise.
483+
{name: "LessThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x<y false otherwise.
484+
{name: "LessEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x<=y false otherwise.
485+
{name: "GreaterThan", argLength: 1, reg: readflags}, // bool, true flags encode signed x>y false otherwise.
486+
{name: "GreaterEqual", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y false otherwise.
487+
{name: "LessThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<y false otherwise.
488+
{name: "LessEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x<=y false otherwise.
489+
{name: "GreaterThanU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>y false otherwise.
490+
{name: "GreaterEqualU", argLength: 1, reg: readflags}, // bool, true flags encode unsigned x>=y false otherwise.
491+
{name: "LessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y false otherwise.
492+
{name: "LessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y false otherwise.
493+
{name: "GreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y false otherwise.
494+
{name: "GreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y false otherwise.
495+
{name: "NotLessThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>=y || x is unordered with y, false otherwise.
496+
{name: "NotLessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y || x is unordered with y, false otherwise.
497+
{name: "NotGreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y || x is unordered with y, false otherwise.
498+
{name: "NotGreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y || x is unordered with y, false otherwise.
495499
// duffzero
496500
// arg0 = address of memory to zero
497501
// arg1 = mem

src/cmd/compile/internal/ssa/opGen.go

+40
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cmd/compile/internal/ssa/rewrite.go

+23-8
Original file line numberDiff line numberDiff line change
@@ -974,9 +974,10 @@ func flagArg(v *Value) *Value {
974974
}
975975

976976
// arm64Negate finds the complement to an ARM64 condition code,
977-
// for example Equal -> NotEqual or LessThan -> GreaterEqual
977+
// for example !Equal -> NotEqual or !LessThan -> GreaterEqual
978978
//
979-
// TODO: add floating-point conditions
979+
// For floating point, it's more subtle because NaN is unordered. We do
980+
// !LessThanF -> NotLessThanF, the latter takes care of NaNs.
980981
func arm64Negate(op Op) Op {
981982
switch op {
982983
case OpARM64LessThan:
@@ -1000,13 +1001,21 @@ func arm64Negate(op Op) Op {
10001001
case OpARM64NotEqual:
10011002
return OpARM64Equal
10021003
case OpARM64LessThanF:
1003-
return OpARM64GreaterEqualF
1004-
case OpARM64GreaterThanF:
1005-
return OpARM64LessEqualF
1004+
return OpARM64NotLessThanF
1005+
case OpARM64NotLessThanF:
1006+
return OpARM64LessThanF
10061007
case OpARM64LessEqualF:
1008+
return OpARM64NotLessEqualF
1009+
case OpARM64NotLessEqualF:
1010+
return OpARM64LessEqualF
1011+
case OpARM64GreaterThanF:
1012+
return OpARM64NotGreaterThanF
1013+
case OpARM64NotGreaterThanF:
10071014
return OpARM64GreaterThanF
10081015
case OpARM64GreaterEqualF:
1009-
return OpARM64LessThanF
1016+
return OpARM64NotGreaterEqualF
1017+
case OpARM64NotGreaterEqualF:
1018+
return OpARM64GreaterEqualF
10101019
default:
10111020
panic("unreachable")
10121021
}
@@ -1017,8 +1026,6 @@ func arm64Negate(op Op) Op {
10171026
// that the same result would be produced if the arguments
10181027
// to the flag-generating instruction were reversed, e.g.
10191028
// (InvertFlags (CMP x y)) -> (CMP y x)
1020-
//
1021-
// TODO: add floating-point conditions
10221029
func arm64Invert(op Op) Op {
10231030
switch op {
10241031
case OpARM64LessThan:
@@ -1047,6 +1054,14 @@ func arm64Invert(op Op) Op {
10471054
return OpARM64GreaterEqualF
10481055
case OpARM64GreaterEqualF:
10491056
return OpARM64LessEqualF
1057+
case OpARM64NotLessThanF:
1058+
return OpARM64NotGreaterThanF
1059+
case OpARM64NotGreaterThanF:
1060+
return OpARM64NotLessThanF
1061+
case OpARM64NotLessEqualF:
1062+
return OpARM64NotGreaterEqualF
1063+
case OpARM64NotGreaterEqualF:
1064+
return OpARM64NotLessEqualF
10501065
default:
10511066
panic("unreachable")
10521067
}

test/fixedbugs/issue43619.go

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// run
2+
3+
// Copyright 2021 The Go Authors. All rights reserved.
4+
// Use of this source code is governed by a BSD-style
5+
// license that can be found in the LICENSE file.
6+
7+
package main
8+
9+
import (
10+
"fmt"
11+
"math"
12+
)
13+
14+
//go:noinline
15+
func fcmplt(a, b float64, x uint64) uint64 {
16+
if a < b {
17+
x = 0
18+
}
19+
return x
20+
}
21+
22+
//go:noinline
23+
func fcmple(a, b float64, x uint64) uint64 {
24+
if a <= b {
25+
x = 0
26+
}
27+
return x
28+
}
29+
30+
//go:noinline
31+
func fcmpgt(a, b float64, x uint64) uint64 {
32+
if a > b {
33+
x = 0
34+
}
35+
return x
36+
}
37+
38+
//go:noinline
39+
func fcmpge(a, b float64, x uint64) uint64 {
40+
if a >= b {
41+
x = 0
42+
}
43+
return x
44+
}
45+
46+
//go:noinline
47+
func fcmpeq(a, b float64, x uint64) uint64 {
48+
if a == b {
49+
x = 0
50+
}
51+
return x
52+
}
53+
54+
//go:noinline
55+
func fcmpne(a, b float64, x uint64) uint64 {
56+
if a != b {
57+
x = 0
58+
}
59+
return x
60+
}
61+
62+
func main() {
63+
type fn func(a, b float64, x uint64) uint64
64+
65+
type testCase struct {
66+
f fn
67+
a, b float64
68+
x, want uint64
69+
}
70+
NaN := math.NaN()
71+
for _, t := range []testCase{
72+
{fcmplt, 1.0, 1.0, 123, 123},
73+
{fcmple, 1.0, 1.0, 123, 0},
74+
{fcmpgt, 1.0, 1.0, 123, 123},
75+
{fcmpge, 1.0, 1.0, 123, 0},
76+
{fcmpeq, 1.0, 1.0, 123, 0},
77+
{fcmpne, 1.0, 1.0, 123, 123},
78+
79+
{fcmplt, 1.0, 2.0, 123, 0},
80+
{fcmple, 1.0, 2.0, 123, 0},
81+
{fcmpgt, 1.0, 2.0, 123, 123},
82+
{fcmpge, 1.0, 2.0, 123, 123},
83+
{fcmpeq, 1.0, 2.0, 123, 123},
84+
{fcmpne, 1.0, 2.0, 123, 0},
85+
86+
{fcmplt, 2.0, 1.0, 123, 123},
87+
{fcmple, 2.0, 1.0, 123, 123},
88+
{fcmpgt, 2.0, 1.0, 123, 0},
89+
{fcmpge, 2.0, 1.0, 123, 0},
90+
{fcmpeq, 2.0, 1.0, 123, 123},
91+
{fcmpne, 2.0, 1.0, 123, 0},
92+
93+
{fcmplt, 1.0, NaN, 123, 123},
94+
{fcmple, 1.0, NaN, 123, 123},
95+
{fcmpgt, 1.0, NaN, 123, 123},
96+
{fcmpge, 1.0, NaN, 123, 123},
97+
{fcmpeq, 1.0, NaN, 123, 123},
98+
{fcmpne, 1.0, NaN, 123, 0},
99+
100+
{fcmplt, NaN, 1.0, 123, 123},
101+
{fcmple, NaN, 1.0, 123, 123},
102+
{fcmpgt, NaN, 1.0, 123, 123},
103+
{fcmpge, NaN, 1.0, 123, 123},
104+
{fcmpeq, NaN, 1.0, 123, 123},
105+
{fcmpne, NaN, 1.0, 123, 0},
106+
107+
{fcmplt, NaN, NaN, 123, 123},
108+
{fcmple, NaN, NaN, 123, 123},
109+
{fcmpgt, NaN, NaN, 123, 123},
110+
{fcmpge, NaN, NaN, 123, 123},
111+
{fcmpeq, NaN, NaN, 123, 123},
112+
{fcmpne, NaN, NaN, 123, 0},
113+
} {
114+
got := t.f(t.a, t.b, t.x)
115+
if got != t.want {
116+
panic(fmt.Sprintf("want %v, got %v", t.want, got))
117+
}
118+
}
119+
}

0 commit comments

Comments
 (0)