Skip to content

Commit 895a7b6

Browse files
committed
8342967: Lambda deduplication fails with non-metafactory BSMs and mismatched local variables names
Reviewed-by: mcimadamore
1 parent b41d713 commit 895a7b6

File tree

6 files changed

+202
-113
lines changed

6 files changed

+202
-113
lines changed

src/jdk.compiler/share/classes/com/sun/tools/javac/comp/LambdaToMethod.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ class DedupedLambda {
217217
public int hashCode() {
218218
int hashCode = this.hashCode;
219219
if (hashCode == 0) {
220-
this.hashCode = hashCode = TreeHasher.hash(tree, symbol.params());
220+
this.hashCode = hashCode = TreeHasher.hash(types, tree, symbol.params());
221221
}
222222
return hashCode;
223223
}
@@ -226,7 +226,7 @@ public int hashCode() {
226226
public boolean equals(Object o) {
227227
return (o instanceof DedupedLambda dedupedLambda)
228228
&& types.isSameType(symbol.asType(), dedupedLambda.symbol.asType())
229-
&& new TreeDiffer(symbol.params(), dedupedLambda.symbol.params()).scan(tree, dedupedLambda.tree);
229+
&& new TreeDiffer(types, symbol.params(), dedupedLambda.symbol.params()).scan(tree, dedupedLambda.tree);
230230
}
231231
}
232232

src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TransPatterns.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ public void resolve(VarSymbol commonBinding,
993993
!currentNullable &&
994994
!previousCompletesNormally &&
995995
!currentCompletesNormally &&
996-
new TreeDiffer(List.of(commonBinding), List.of(currentBinding))
996+
new TreeDiffer(types, List.of(commonBinding), List.of(currentBinding))
997997
.scan(commonNestedExpression, currentNestedExpression)) {
998998
accummulator.add(c.head);
999999
} else {

src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeDiffer.java

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
import com.sun.tools.javac.code.Flags;
3030
import com.sun.tools.javac.code.Symbol;
31+
import com.sun.tools.javac.code.TypeTag;
32+
import com.sun.tools.javac.code.Types;
33+
import com.sun.tools.javac.jvm.PoolConstant;
3134
import com.sun.tools.javac.tree.JCTree;
3235
import com.sun.tools.javac.tree.JCTree.JCAnnotatedType;
3336
import com.sun.tools.javac.tree.JCTree.JCAnnotation;
@@ -107,10 +110,10 @@
107110

108111
/** A visitor that compares two lambda bodies for structural equality. */
109112
public class TreeDiffer extends TreeScanner {
110-
111-
public TreeDiffer(
112-
Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
113+
public TreeDiffer(Types types,
114+
Collection<? extends Symbol> symbols, Collection<? extends Symbol> otherSymbols) {
113115
this.equiv = equiv(symbols, otherSymbols);
116+
this.types = types;
114117
}
115118

116119
private static Map<Symbol, Symbol> equiv(
@@ -127,6 +130,7 @@ private static Map<Symbol, Symbol> equiv(
127130
private JCTree parameter;
128131
private boolean result;
129132
private Map<Symbol, Symbol> equiv = new HashMap<>();
133+
final Types types;
130134

131135
public boolean scan(JCTree tree, JCTree parameter) {
132136
if (tree == null || parameter == null) {
@@ -197,13 +201,24 @@ public void visitIdent(JCIdent tree) {
197201
return;
198202
}
199203
}
200-
result = tree.sym == that.sym;
204+
result = scanSymbol(symbol, otherSymbol);
205+
}
206+
207+
private boolean scanSymbol(Symbol symbol, Symbol otherSymbol) {
208+
if (symbol instanceof PoolConstant.Dynamic dms && otherSymbol instanceof PoolConstant.Dynamic other_dms) {
209+
return dms.bsmKey(types).equals(other_dms.bsmKey(types));
210+
}
211+
else {
212+
return symbol == otherSymbol;
213+
}
201214
}
202215

203216
@Override
204217
public void visitSelect(JCFieldAccess tree) {
205218
JCFieldAccess that = (JCFieldAccess) parameter;
206-
result = scan(tree.selected, that.selected) && tree.sym == that.sym;
219+
220+
result = scan(tree.selected, that.selected) &&
221+
scanSymbol(tree.sym, that.sym);
207222
}
208223

209224
@Override
@@ -328,14 +343,7 @@ public void visitCatch(JCCatch tree) {
328343

329344
@Override
330345
public void visitClassDef(JCClassDecl tree) {
331-
JCClassDecl that = (JCClassDecl) parameter;
332-
result =
333-
scan(tree.mods, that.mods)
334-
&& tree.name == that.name
335-
&& scan(tree.typarams, that.typarams)
336-
&& scan(tree.extending, that.extending)
337-
&& scan(tree.implementing, that.implementing)
338-
&& scan(tree.defs, that.defs);
346+
result = false;
339347
}
340348

341349
@Override
@@ -667,14 +675,18 @@ public void visitVarDef(JCVariableDecl tree) {
667675
JCVariableDecl that = (JCVariableDecl) parameter;
668676
result =
669677
scan(tree.mods, that.mods)
670-
&& tree.name == that.name
671678
&& scan(tree.nameexpr, that.nameexpr)
672679
&& scan(tree.vartype, that.vartype)
673680
&& scan(tree.init, that.init);
674-
if (!result) {
675-
return;
681+
682+
if (tree.sym.owner.type.hasTag(TypeTag.CLASS)) {
683+
// field names are important!
684+
result &= tree.name == that.name;
685+
}
686+
687+
if (result) {
688+
equiv.put(tree.sym, that.sym);
676689
}
677-
equiv.put(tree.sym, that.sym);
678690
}
679691

680692
@Override

src/jdk.compiler/share/classes/com/sun/tools/javac/comp/TreeHasher.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727
package com.sun.tools.javac.comp;
2828

2929
import com.sun.tools.javac.code.Symbol;
30+
import com.sun.tools.javac.code.Types;
31+
import com.sun.tools.javac.jvm.PoolConstant;
3032
import com.sun.tools.javac.tree.JCTree;
33+
import com.sun.tools.javac.tree.JCTree.JCClassDecl;
3134
import com.sun.tools.javac.tree.JCTree.JCFieldAccess;
3235
import com.sun.tools.javac.tree.JCTree.JCIdent;
3336
import com.sun.tools.javac.tree.JCTree.JCLiteral;
@@ -43,19 +46,21 @@
4346
public class TreeHasher extends TreeScanner {
4447

4548
private final Map<Symbol, Integer> symbolHashes;
49+
private final Types types;
4650
private int result = 17;
4751

48-
public TreeHasher(Map<Symbol, Integer> symbolHashes) {
52+
public TreeHasher(Types types, Map<Symbol, Integer> symbolHashes) {
4953
this.symbolHashes = Objects.requireNonNull(symbolHashes);
54+
this.types = types;
5055
}
5156

52-
public static int hash(JCTree tree, Collection<? extends Symbol> symbols) {
57+
public static int hash(Types types, JCTree tree, Collection<? extends Symbol> symbols) {
5358
if (tree == null) {
5459
return 0;
5560
}
5661
Map<Symbol, Integer> symbolHashes = new HashMap<>();
5762
symbols.forEach(s -> symbolHashes.put(s, symbolHashes.size()));
58-
TreeHasher hasher = new TreeHasher(symbolHashes);
63+
TreeHasher hasher = new TreeHasher(types, symbolHashes);
5964
tree.accept(hasher);
6065
return hasher.result;
6166
}
@@ -87,6 +92,11 @@ public void visitLiteral(JCLiteral tree) {
8792
super.visitLiteral(tree);
8893
}
8994

95+
@Override
96+
public void visitClassDef(JCClassDecl tree) {
97+
hash(tree.sym);
98+
}
99+
90100
@Override
91101
public void visitIdent(JCIdent tree) {
92102
Symbol sym = tree.sym;
@@ -97,15 +107,23 @@ public void visitIdent(JCIdent tree) {
97107
return;
98108
}
99109
}
100-
hash(sym);
110+
hashSymbol(sym);
101111
}
102112

103113
@Override
104114
public void visitSelect(JCFieldAccess tree) {
105-
hash(tree.sym);
115+
hashSymbol(tree.sym);
106116
super.visitSelect(tree);
107117
}
108118

119+
private void hashSymbol(Symbol sym) {
120+
if (sym instanceof PoolConstant.Dynamic dynamic) {
121+
hash(dynamic.bsmKey(types));
122+
} else {
123+
hash(sym);
124+
}
125+
}
126+
109127
@Override
110128
public void visitVarDef(JCVariableDecl tree) {
111129
symbolHashes.computeIfAbsent(tree.sym, k -> symbolHashes.size());

test/langtools/tools/javac/lambda/deduplication/Deduplication.java

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,52 +29,54 @@
2929
import java.util.function.Supplier;
3030

3131
public class Deduplication {
32+
void groupEquals(Object... xs) {}
33+
void groupNotEquals(Object... xs) {}
3234
void group(Object... xs) {}
3335

3436
void test() {
3537

36-
group(
38+
groupEquals(
3739
(Runnable) () -> { ( (Runnable) () -> {} ).run(); },
3840
(Runnable) () -> { ( (Runnable) () -> {} ).run(); }
3941
);
4042

41-
group(
43+
groupEquals(
4244
(Runnable) () -> { Deduplication.class.toString(); },
4345
(Runnable) () -> { Deduplication.class.toString(); }
4446
);
4547

46-
group(
48+
groupEquals(
4749
(Runnable) () -> { Integer[].class.toString(); },
4850
(Runnable) () -> { Integer[].class.toString(); }
4951
);
5052

51-
group(
53+
groupEquals(
5254
(Runnable) () -> { char.class.toString(); },
5355
(Runnable) () -> { char.class.toString(); }
5456
);
5557

56-
group(
58+
groupEquals(
5759
(Runnable) () -> { Void.class.toString(); },
5860
(Runnable) () -> { Void.class.toString(); }
5961
);
6062

61-
group(
63+
groupEquals(
6264
(Runnable) () -> { void.class.toString(); },
6365
(Runnable) () -> { void.class.toString(); }
6466
);
6567

66-
group((Function<String, Integer>) x -> x.hashCode());
67-
group((Function<Object, Integer>) x -> x.hashCode());
68+
groupEquals((Function<String, Integer>) x -> x.hashCode());
69+
groupEquals((Function<Object, Integer>) x -> x.hashCode());
6870

6971
{
7072
int x = 1;
71-
group((Supplier<Integer>) () -> x + 1);
73+
groupEquals((Supplier<Integer>) () -> x + 1);
7274
}
7375
{
7476
int x = 1;
75-
group((Supplier<Integer>) () -> x + 1);
77+
groupEquals((Supplier<Integer>) () -> x + 1);
7678
}
77-
group(
79+
groupEquals(
7880
(BiFunction<Integer, Integer, ?>) (x, y) -> x + ((y)),
7981
(BiFunction<Integer, Integer, ?>) (x, y) -> x + (y),
8082
(BiFunction<Integer, Integer, ?>) (x, y) -> x + y,
@@ -85,29 +87,29 @@ void test() {
8587
(BiFunction<Integer, Integer, ?>) (x, y) -> ((x)) + (y),
8688
(BiFunction<Integer, Integer, ?>) (x, y) -> ((x)) + y);
8789

88-
group(
90+
groupEquals(
8991
(Function<Integer, Integer>) x -> x + (1 + 2 + 3),
9092
(Function<Integer, Integer>) x -> x + 6);
9193

92-
group((Function<Integer, Integer>) x -> x + 1, (Function<Integer, Integer>) y -> y + 1);
94+
groupEquals((Function<Integer, Integer>) x -> x + 1, (Function<Integer, Integer>) y -> y + 1);
9395

94-
group((Consumer<Integer>) x -> this.f(), (Consumer<Integer>) x -> this.f());
96+
groupEquals((Consumer<Integer>) x -> this.f(), (Consumer<Integer>) x -> this.f());
9597

96-
group((Consumer<Integer>) y -> this.g());
98+
groupEquals((Consumer<Integer>) y -> this.g());
9799

98-
group((Consumer<Integer>) x -> f(), (Consumer<Integer>) x -> f());
100+
groupEquals((Consumer<Integer>) x -> f(), (Consumer<Integer>) x -> f());
99101

100-
group((Consumer<Integer>) y -> g());
102+
groupEquals((Consumer<Integer>) y -> g());
101103

102-
group((Function<Integer, Integer>) x -> this.i, (Function<Integer, Integer>) x -> this.i);
104+
groupEquals((Function<Integer, Integer>) x -> this.i, (Function<Integer, Integer>) x -> this.i);
103105

104-
group((Function<Integer, Integer>) y -> this.j);
106+
groupEquals((Function<Integer, Integer>) y -> this.j);
105107

106-
group((Function<Integer, Integer>) x -> i, (Function<Integer, Integer>) x -> i);
108+
groupEquals((Function<Integer, Integer>) x -> i, (Function<Integer, Integer>) x -> i);
107109

108-
group((Function<Integer, Integer>) y -> j);
110+
groupEquals((Function<Integer, Integer>) y -> j);
109111

110-
group(
112+
groupEquals(
111113
(Function<Integer, Integer>)
112114
y -> {
113115
while (true) {
@@ -123,7 +125,7 @@ void test() {
123125
return 42;
124126
});
125127

126-
group(
128+
groupEquals(
127129
(Function<Integer, Integer>)
128130
x -> {
129131
int y = x;
@@ -135,13 +137,13 @@ void test() {
135137
return y;
136138
});
137139

138-
group(
140+
groupEquals(
139141
(Function<Integer, Integer>)
140142
x -> {
141143
int y = 0, z = x;
142144
return y;
143145
});
144-
group(
146+
groupEquals(
145147
(Function<Integer, Integer>)
146148
x -> {
147149
int y = 0, z = x;
@@ -154,24 +156,41 @@ class Local {
154156
void f() {}
155157

156158
{
157-
group((Function<Integer, Integer>) x -> this.i);
158-
group((Consumer<Integer>) x -> this.f());
159-
group((Function<Integer, Integer>) x -> Deduplication.this.i);
160-
group((Consumer<Integer>) x -> Deduplication.this.f());
159+
groupEquals((Function<Integer, Integer>) x -> this.i);
160+
groupEquals((Consumer<Integer>) x -> this.f());
161+
groupEquals((Function<Integer, Integer>) x -> Deduplication.this.i);
162+
groupEquals((Consumer<Integer>) x -> Deduplication.this.f());
161163
}
162164
}
163165

164-
group((Function<Integer, Integer>) x -> switch (x) { default: yield x; },
166+
groupEquals((Function<Integer, Integer>) x -> switch (x) { default: yield x; },
165167
(Function<Integer, Integer>) x -> switch (x) { default: yield x; });
166168

167-
group((Function<Object, Integer>) x -> x instanceof Integer i ? i : -1,
169+
groupEquals((Function<Object, Integer>) x -> x instanceof Integer i ? i : -1,
168170
(Function<Object, Integer>) x -> x instanceof Integer i ? i : -1);
169171

170-
group((Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1,
172+
groupEquals((Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1,
171173
(Function<Object, Integer>) x -> x instanceof R(var i1, var i2) ? i1 : -1 );
172174

173-
group((Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1,
175+
groupEquals((Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1,
174176
(Function<Object, Integer>) x -> x instanceof R(Integer i1, int i2) ? i2 : -1 );
177+
178+
groupEquals((Function<Object, Integer>) x -> x instanceof int i2 ? i2 : -1,
179+
(Function<Object, Integer>) x -> x instanceof int i2 ? i2 : -1);
180+
181+
groupEquals((Function<Object, Integer>) x -> switch (x) { case String s -> s.length(); default -> -1; },
182+
(Function<Object, Integer>) x -> switch (x) { case String s -> s.length(); default -> -1; });
183+
184+
groupEquals((Function<Object, Integer>) x -> {
185+
int y1 = -1;
186+
return y1;
187+
},
188+
(Function<Object, Integer>) x -> {
189+
int y2 = -1;
190+
return y2;
191+
});
192+
193+
groupNotEquals((Function<Object, Integer>) x -> {class C {} new C(); return 42; }, (Function<Object, Integer>) x -> {class C {} new C(); return 42; });
175194
}
176195

177196
void f() {}

0 commit comments

Comments
 (0)