Skip to content

Commit 54577c6

Browse files
committed
feature: reuse common
1 parent a45ce12 commit 54577c6

File tree

8 files changed

+183
-164
lines changed

8 files changed

+183
-164
lines changed

bench_test.go

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -407,30 +407,34 @@ func Benchmark_realWorldInsane(b *testing.B) {
407407
b.Fatal(err)
408408
}
409409

410-
b.Run("no_cse", func(b *testing.B) {
411-
var out interface{}
412-
out, err = vm.Run(program, env)
413-
if err != nil {
414-
b.Fatal(err)
415-
}
416-
if !out.(bool) {
417-
b.Fail()
410+
b.Run("no_common_reused", func(b *testing.B) {
411+
for i := 0; i < b.N; i++ {
412+
var out interface{}
413+
out, err = vm.Run(program, env)
414+
if err != nil {
415+
b.Fatal(err)
416+
}
417+
if !out.(bool) {
418+
b.Fail()
419+
}
418420
}
419421
})
420422

421-
program, err = expr.Compile(expression, expr.Env(env), expr.AllowCommonSubExprElimination(true))
423+
program, err = expr.Compile(expression, expr.Env(env), expr.AllowReuseCommon(true))
422424
if err != nil {
423425
b.Fatal(err)
424426
}
425427

426-
b.Run("allow_cse", func(b *testing.B) {
427-
var out interface{}
428-
out, err = vm.Run(program, env)
429-
if err != nil {
430-
b.Fatal(err)
431-
}
432-
if !out.(bool) {
433-
b.Fail()
428+
b.Run("allow_common_reused", func(b *testing.B) {
429+
for i := 0; i < b.N; i++ {
430+
var out interface{}
431+
out, err = vm.Run(program, env)
432+
if err != nil {
433+
b.Fatal(err)
434+
}
435+
if !out.(bool) {
436+
b.Fail()
437+
}
434438
}
435439
})
436440
}
Lines changed: 85 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package compiler
22

33
import (
4-
"crypto/md5"
4+
"crypto/sha1"
55
"fmt"
66
"sort"
77
"strconv"
@@ -11,78 +11,78 @@ import (
1111
"github.com/antonmedv/expr/file"
1212
)
1313

14-
func (c *compiler) checkCommonSubExpr(node ast.Node) {
14+
func (c *compiler) countCommonExpr(node ast.Node) {
1515
switch n := node.(type) {
1616
case *ast.NilNode:
17-
c.checkNilNode(n)
17+
c.commonCommonNilNode(n)
1818
case *ast.IdentifierNode:
19-
c.checkIdentifierNode(n)
19+
c.countCommonIdentifierNode(n)
2020
case *ast.IntegerNode:
21-
c.checkIntegerNode(n)
21+
c.countCommonIntegerNode(n)
2222
case *ast.FloatNode:
23-
c.checkFloatNode(n)
23+
c.countCommonFloatNode(n)
2424
case *ast.BoolNode:
25-
c.checkBoolNode(n)
25+
c.countCommonBoolNode(n)
2626
case *ast.StringNode:
27-
c.checkStringNode(n)
27+
c.countCommonStringNode(n)
2828
case *ast.ConstantNode:
29-
c.checkConstantNode(n)
29+
c.countCommonConstantNode(n)
3030
case *ast.UnaryNode:
31-
c.checkUnaryNode(n)
31+
c.countCommonUnaryNode(n)
3232
case *ast.BinaryNode:
33-
c.checkBinaryNode(n)
33+
c.countCommonBinaryNode(n)
3434
case *ast.ChainNode:
35-
c.checkChainNode(n)
35+
c.countCommonChainNode(n)
3636
case *ast.MemberNode:
37-
c.checkMemberNode(n)
37+
c.countCommonMemberNode(n)
3838
case *ast.SliceNode:
39-
c.checkSliceNode(n)
39+
c.countCommonSliceNode(n)
4040
case *ast.CallNode:
4141
c.checkCallNode(n)
4242
case *ast.BuiltinNode:
43-
c.checkBuiltinNode(n)
43+
c.countCommonBuiltinNode(n)
4444
case *ast.ClosureNode:
45-
c.checkClosureNode(n)
45+
c.countCommonClosureNode(n)
4646
case *ast.PointerNode:
47-
c.checkPointerNode(n)
47+
c.countCommonPointerNode(n)
4848
case *ast.ConditionalNode:
49-
c.checkConditionalNode(n)
49+
c.countCommonConditionalNode(n)
5050
case *ast.ArrayNode:
51-
c.checkArrayNode(n)
51+
c.countCommonArrayNode(n)
5252
case *ast.MapNode:
53-
c.checkMapNode(n)
53+
c.countCommonMapNode(n)
5454
case *ast.PairNode:
55-
c.checkPairNode(n)
55+
c.countCommonPairNode(n)
5656
default:
5757
panic(fmt.Sprintf("undefined node type (%T)", node))
5858
}
5959
}
6060

61-
func (c *compiler) checkNilNode(n *ast.NilNode) {
61+
func (c *compiler) commonCommonNilNode(n *ast.NilNode) {
6262
n.SetSubExpr("nil")
6363
}
6464

65-
func (c *compiler) checkIdentifierNode(n *ast.IdentifierNode) {
65+
func (c *compiler) countCommonIdentifierNode(n *ast.IdentifierNode) {
6666
n.SetSubExpr(n.Value)
6767
}
6868

69-
func (c *compiler) checkIntegerNode(n *ast.IntegerNode) {
69+
func (c *compiler) countCommonIntegerNode(n *ast.IntegerNode) {
7070
n.SetSubExpr(strconv.FormatInt(int64(n.Value), 10))
7171
}
7272

73-
func (c *compiler) checkFloatNode(n *ast.FloatNode) {
73+
func (c *compiler) countCommonFloatNode(n *ast.FloatNode) {
7474
n.SetSubExpr(strconv.FormatFloat(n.Value, 'f', 10, 64))
7575
}
7676

77-
func (c *compiler) checkBoolNode(n *ast.BoolNode) {
77+
func (c *compiler) countCommonBoolNode(n *ast.BoolNode) {
7878
n.SetSubExpr(strconv.FormatBool(n.Value))
7979
}
8080

81-
func (c *compiler) checkStringNode(n *ast.StringNode) {
81+
func (c *compiler) countCommonStringNode(n *ast.StringNode) {
8282
n.SetSubExpr(strconv.Quote(n.Value))
8383
}
8484

85-
func (c *compiler) checkConstantNode(n *ast.ConstantNode) {
85+
func (c *compiler) countCommonConstantNode(n *ast.ConstantNode) {
8686
switch n.Value.(type) {
8787
case string:
8888
n.SetSubExpr(strconv.Quote(n.Value.(string)))
@@ -115,8 +115,8 @@ func (c *compiler) checkConstantNode(n *ast.ConstantNode) {
115115
}
116116
}
117117

118-
func (c *compiler) checkUnaryNode(n *ast.UnaryNode) {
119-
c.checkCommonSubExpr(n.Node)
118+
func (c *compiler) countCommonUnaryNode(n *ast.UnaryNode) {
119+
c.countCommonExpr(n.Node)
120120
switch n.Operator {
121121
case "+":
122122
n.SetSubExpr(n.Node.SubExpr())
@@ -127,23 +127,27 @@ func (c *compiler) checkUnaryNode(n *ast.UnaryNode) {
127127
}
128128
}
129129

130-
func (c *compiler) checkBinaryNode(n *ast.BinaryNode) {
131-
c.checkCommonSubExpr(n.Left)
132-
c.checkCommonSubExpr(n.Right)
130+
func (c *compiler) countCommonBinaryNode(n *ast.BinaryNode) {
131+
c.countCommonExpr(n.Left)
132+
c.countCommonExpr(n.Right)
133133
switch n.Operator {
134134
case "==", "!=", "and", "or", "+", "*", "||", "&&": // right / left can be swap
135135
ls := n.Left.SubExpr()
136136
rs := n.Right.SubExpr()
137137
if rs <= ls {
138138
ls, rs = rs, ls
139139
}
140-
if n.Operator == "and" {
141-
n.SetSubExpr(fmt.Sprintf("%s && %s", ls, rs))
142-
} else if n.Operator == "or" {
143-
n.SetSubExpr(fmt.Sprintf("%s || %s", ls, rs))
140+
if n.Operator == "&&" {
141+
n.SetSubExpr(fmt.Sprintf("%s and %s", ls, rs))
142+
} else if n.Operator == "||" {
143+
n.SetSubExpr(fmt.Sprintf("%s or %s", ls, rs))
144144
} else {
145145
n.SetSubExpr(fmt.Sprintf("%s %s %s", ls, n.Operator, rs))
146146
}
147+
case ">=":
148+
n.SetSubExpr(fmt.Sprintf("%s < %s", n.Right.SubExpr(), n.Left.SubExpr()))
149+
case "<=":
150+
n.SetSubExpr(fmt.Sprintf("%s > %s", n.Right.SubExpr(), n.Left.SubExpr()))
147151
case "**", "^":
148152
n.SetSubExpr(fmt.Sprintf("%s ** %s", n.Left.SubExpr(), n.Right.SubExpr()))
149153
default:
@@ -157,75 +161,75 @@ func (c *compiler) checkBinaryNode(n *ast.BinaryNode) {
157161
}
158162
}
159163

160-
func (c *compiler) checkChainNode(n *ast.ChainNode) {
161-
c.checkCommonSubExpr(n.Node)
164+
func (c *compiler) countCommonChainNode(n *ast.ChainNode) {
165+
c.countCommonExpr(n.Node)
162166
}
163167

164-
func (c *compiler) checkMemberNode(n *ast.MemberNode) {
165-
c.checkCommonSubExpr(n.Node)
166-
c.checkCommonSubExpr(n.Property)
168+
func (c *compiler) countCommonMemberNode(n *ast.MemberNode) {
169+
c.countCommonExpr(n.Node)
170+
c.countCommonExpr(n.Property)
167171
optional := ""
168172
if n.Optional {
169173
optional = "?"
170174
}
171175
n.SetSubExpr(fmt.Sprintf("%s%s.%s", n.Node.SubExpr(), optional, n.Property.SubExpr()))
172176
}
173177

174-
func (c *compiler) checkSliceNode(n *ast.SliceNode) {
175-
c.checkCommonSubExpr(n.Node)
176-
c.checkCommonSubExpr(n.To)
177-
c.checkCommonSubExpr(n.From)
178+
func (c *compiler) countCommonSliceNode(n *ast.SliceNode) {
179+
c.countCommonExpr(n.Node)
180+
c.countCommonExpr(n.To)
181+
c.countCommonExpr(n.From)
178182
n.SetSubExpr(fmt.Sprintf("%s[%s:%s]", n.Node.SubExpr(), n.From.SubExpr(), n.To.SubExpr()))
179183
}
180184

181185
func (c *compiler) checkCallNode(n *ast.CallNode) {
182186
s := make([]string, 0)
183187
for _, arg := range n.Arguments {
184-
c.checkCommonSubExpr(arg)
188+
c.countCommonExpr(arg)
185189
s = append(s, arg.SubExpr())
186190
}
187-
c.checkCommonSubExpr(n.Callee)
191+
c.countCommonExpr(n.Callee)
188192
n.SetSubExpr(fmt.Sprintf("%s(%s)", n.Callee.SubExpr(), strings.Join(s, ",")))
189193
}
190194

191-
func (c *compiler) checkBuiltinNode(n *ast.BuiltinNode) {
195+
func (c *compiler) countCommonBuiltinNode(n *ast.BuiltinNode) {
192196
s := make([]string, 0)
193197
for _, arg := range n.Arguments {
194-
c.checkCommonSubExpr(arg)
198+
c.countCommonExpr(arg)
195199
s = append(s, arg.SubExpr())
196200
}
197201
n.SetSubExpr(fmt.Sprintf("%s(%s)", n.Name, strings.Join(s, ",")))
198202
}
199203

200-
func (c *compiler) checkClosureNode(n *ast.ClosureNode) {
201-
c.checkCommonSubExpr(n.Node)
204+
func (c *compiler) countCommonClosureNode(n *ast.ClosureNode) {
205+
c.countCommonExpr(n.Node)
202206
}
203207

204-
func (c *compiler) checkPointerNode(n *ast.PointerNode) {
208+
func (c *compiler) countCommonPointerNode(n *ast.PointerNode) {
205209
// do nothing
206210
}
207211

208-
func (c *compiler) checkConditionalNode(n *ast.ConditionalNode) {
209-
c.checkCommonSubExpr(n.Cond)
210-
c.checkCommonSubExpr(n.Exp1)
211-
c.checkCommonSubExpr(n.Exp2)
212+
func (c *compiler) countCommonConditionalNode(n *ast.ConditionalNode) {
213+
c.countCommonExpr(n.Cond)
214+
c.countCommonExpr(n.Exp1)
215+
c.countCommonExpr(n.Exp2)
212216
n.SetSubExpr(fmt.Sprintf("%s ? %s : %s", n.Cond.SubExpr(), n.Exp1.SubExpr(), n.Exp2.SubExpr()))
213217
}
214218

215-
func (c *compiler) checkArrayNode(n *ast.ArrayNode) {
219+
func (c *compiler) countCommonArrayNode(n *ast.ArrayNode) {
216220
s := make([]string, 0)
217221
for _, node := range n.Nodes {
218-
c.checkCommonSubExpr(node)
222+
c.countCommonExpr(node)
219223
s = append(s, node.SubExpr())
220224
}
221225
n.SetSubExpr(fmt.Sprintf("[%s]", strings.Join(s, ",")))
222226
}
223227

224-
func (c *compiler) checkMapNode(n *ast.MapNode) {
228+
func (c *compiler) countCommonMapNode(n *ast.MapNode) {
225229
pairs := make([]*ast.PairNode, 0)
226230
for _, p := range n.Pairs {
227231
pair := p.(*ast.PairNode)
228-
c.checkPairNode(pair)
232+
c.countCommonPairNode(pair)
229233
pairs = append(pairs, pair)
230234
}
231235
sort.Slice(pairs, func(i, j int) bool {
@@ -238,38 +242,38 @@ func (c *compiler) checkMapNode(n *ast.MapNode) {
238242
n.SetSubExpr(fmt.Sprintf("{%s}", strings.Join(s, ",")))
239243
}
240244

241-
func (c *compiler) checkPairNode(n *ast.PairNode) {
242-
c.checkCommonSubExpr(n.Key)
243-
c.checkCommonSubExpr(n.Value)
245+
func (c *compiler) countCommonPairNode(n *ast.PairNode) {
246+
c.countCommonExpr(n.Key)
247+
c.countCommonExpr(n.Value)
244248
n.SetSubExpr(fmt.Sprintf("%s:%s", n.Key.SubExpr(), n.Value.SubExpr()))
245249
}
246250

247251
func (c *compiler) emitSubExpr(subExpr string, loc file.Location) {
248-
if subExpr == "" {
252+
if c.exprRecords == nil || subExpr == "" {
249253
return
250254
}
251-
hash := fmt.Sprintf("%x", md5.Sum([]byte(subExpr)))
252-
if cs, ok := c.subExprCache[hash]; !ok {
253-
c.subExprCache[hash] = &SubExprRecord{
254-
cnt: 1,
255-
}
255+
hash := fmt.Sprintf("%x", sha1.Sum([]byte(subExpr)))
256+
if cs, ok := c.exprRecords[hash]; !ok {
257+
c.exprRecords[hash] = &exprRecord{cnt: 1}
256258
} else {
257259
if cs.cnt == 1 {
258-
c.subExprUniqId += 1
259-
cs.id = c.subExprUniqId
260+
cs.id = c.commonExprInc
261+
c.commonExpr[cs.id] = subExpr
262+
c.commonExprInc += 1
260263
}
261264
cs.cnt = cs.cnt + 1
262265
}
263266
}
264267

265-
func (c *compiler) checkNeedCSE(n ast.Node) (bool, int) {
266-
needCSE := false
267-
cseUniqId := 0
268-
hash := fmt.Sprintf("%x", md5.Sum([]byte(n.SubExpr())))
269-
cs, ok := c.subExprCache[hash]
270-
if ok && cs.cnt > 1 {
271-
needCSE = true
272-
cseUniqId = cs.id
268+
func (c *compiler) needCacheCommon(n ast.Node) (bool, int) {
269+
needCacheCommon, exprUniqId := false, -1
270+
if c.exprRecords != nil {
271+
hash := fmt.Sprintf("%x", sha1.Sum([]byte(n.SubExpr())))
272+
cs, ok := c.exprRecords[hash]
273+
if ok && cs.cnt > 1 {
274+
needCacheCommon = true
275+
exprUniqId = cs.id
276+
}
273277
}
274-
return needCSE, cseUniqId
278+
return needCacheCommon, exprUniqId
275279
}

0 commit comments

Comments
 (0)