diff --git a/compiler/src/org.graalvm.compiler.core.test/src/org/graalvm/compiler/core/test/ea/UnsafeCompareAndSwapVirtualizationTest.java b/compiler/src/org.graalvm.compiler.core.test/src/org/graalvm/compiler/core/test/ea/UnsafeCompareAndSwapVirtualizationTest.java new file mode 100644 index 000000000000..d6673619ab68 --- /dev/null +++ b/compiler/src/org.graalvm.compiler.core.test/src/org/graalvm/compiler/core/test/ea/UnsafeCompareAndSwapVirtualizationTest.java @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2011, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package org.graalvm.compiler.core.test.ea; + +import java.util.concurrent.atomic.AtomicReference; + +import org.graalvm.compiler.nodes.java.LogicCompareAndSwapNode; +import org.junit.Test; + +import jdk.vm.ci.meta.JavaConstant; + +public class UnsafeCompareAndSwapVirtualizationTest extends EATestBase { + + private static Object obj1 = new Object(); + private static Object obj2 = new Object(); + + public static boolean bothVirtualNoMatch() { + AtomicReference a = new AtomicReference<>(); + return a.compareAndSet(new Object(), new Object()); + } + + @Test + public void bothVirtualNoMatchTest() { + testEscapeAnalysis("bothVirtualNoMatch", JavaConstant.INT_0, true); + assertTrue(graph.getNodes(LogicCompareAndSwapNode.TYPE).isEmpty()); + } + + public static boolean bothVirtualMatch() { + Object expect = new Object(); + AtomicReference a = new AtomicReference<>(expect); + return a.compareAndSet(expect, new Object()); + } + + @Test + public void bothVirtualMatchTest() { + testEscapeAnalysis("bothVirtualMatch", JavaConstant.INT_1, true); + assertTrue(graph.getNodes(LogicCompareAndSwapNode.TYPE).isEmpty()); + } + + public static boolean expectedVirtualMatch() { + Object o = new Object(); + AtomicReference a = new AtomicReference<>(o); + return a.compareAndSet(o, obj1); + } + + @Test + public void expectedVirtualMatchTest() { + testEscapeAnalysis("expectedVirtualMatch", null, true); + assertTrue(graph.getNodes(LogicCompareAndSwapNode.TYPE).isEmpty()); + } + + public static boolean expectedVirtualNoMatch() { + Object o = new Object(); + AtomicReference a = new AtomicReference<>(); + return a.compareAndSet(o, obj1); + } + + @Test + public void expectedVirtualNoMatchTest() { + testEscapeAnalysis("expectedVirtualNoMatch", null, true); + assertTrue(graph.getNodes(LogicCompareAndSwapNode.TYPE).isEmpty()); + } + + public static boolean bothNonVirtualNoMatch() { + AtomicReference a = new AtomicReference<>(); + return a.compareAndSet(obj1, obj2); + } + + @Test + public void bothNonVirtualNoMatchTest() { + testEscapeAnalysis("bothNonVirtualNoMatch", null, true); + assertTrue(graph.getNodes(LogicCompareAndSwapNode.TYPE).isEmpty()); + } + + public static boolean bothNonVirtualMatch() { + AtomicReference a = new AtomicReference<>(obj1); + return a.compareAndSet(obj1, obj2); + } + + @Test + public void bothNonVirtualMatchTest() { + testEscapeAnalysis("bothNonVirtualMatch", null, true); + assertTrue(graph.getNodes(LogicCompareAndSwapNode.TYPE).isEmpty()); + } + + public static boolean onlyInitialValueVirtualMatch() { + AtomicReference a = new AtomicReference<>(new Object()); + return a.compareAndSet(obj1, obj2); + } +} diff --git a/compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/java/UnsafeCompareAndSwapNode.java b/compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/java/UnsafeCompareAndSwapNode.java index 35433a4d498a..a5dcddbef8c6 100644 --- a/compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/java/UnsafeCompareAndSwapNode.java +++ b/compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/java/UnsafeCompareAndSwapNode.java @@ -24,6 +24,7 @@ */ package org.graalvm.compiler.nodes.java; +import static org.graalvm.compiler.core.common.calc.CanonicalCondition.EQ; import static org.graalvm.compiler.nodeinfo.InputType.Memory; import static org.graalvm.compiler.nodeinfo.InputType.Value; import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_8; @@ -32,22 +33,32 @@ import org.graalvm.compiler.core.common.type.StampFactory; import org.graalvm.compiler.graph.NodeClass; import org.graalvm.compiler.nodeinfo.NodeInfo; +import org.graalvm.compiler.nodes.ConstantNode; +import org.graalvm.compiler.nodes.LogicConstantNode; +import org.graalvm.compiler.nodes.LogicNode; import org.graalvm.compiler.nodes.NodeView; import org.graalvm.compiler.nodes.ValueNode; +import org.graalvm.compiler.nodes.calc.CompareNode; +import org.graalvm.compiler.nodes.calc.ConditionalNode; import org.graalvm.compiler.nodes.memory.AbstractMemoryCheckpoint; import org.graalvm.compiler.nodes.memory.MemoryCheckpoint; import org.graalvm.compiler.nodes.spi.Lowerable; import org.graalvm.compiler.nodes.spi.LoweringTool; +import org.graalvm.compiler.nodes.spi.Virtualizable; +import org.graalvm.compiler.nodes.spi.VirtualizerTool; +import org.graalvm.compiler.nodes.virtual.VirtualInstanceNode; +import org.graalvm.compiler.nodes.virtual.VirtualObjectNode; import org.graalvm.word.LocationIdentity; import jdk.vm.ci.meta.JavaKind; +import jdk.vm.ci.meta.ResolvedJavaField; /** * Represents an atomic compare-and-swap operation. The result is a boolean that contains whether * the value matched the expected value. */ @NodeInfo(allowedUsageTypes = {Value, Memory}, cycles = CYCLES_8, size = SIZE_8) -public final class UnsafeCompareAndSwapNode extends AbstractMemoryCheckpoint implements Lowerable, MemoryCheckpoint.Single { +public final class UnsafeCompareAndSwapNode extends AbstractMemoryCheckpoint implements Lowerable, MemoryCheckpoint.Single, Virtualizable { public static final NodeClass TYPE = NodeClass.create(UnsafeCompareAndSwapNode.class); @Input ValueNode object; @@ -98,4 +109,50 @@ public LocationIdentity getLocationIdentity() { public void lower(LoweringTool tool) { tool.getLowerer().lower(this, tool); } + + @Override + public void virtualize(VirtualizerTool tool) { + ValueNode alias = tool.getAlias(object); + if (alias instanceof VirtualInstanceNode && offset.isConstant()) { + + VirtualInstanceNode obj = (VirtualInstanceNode) alias; + + int index = resolveFieldIndex(obj); + + if (index >= 0) { + + ValueNode currentValue = tool.getEntry(obj, index); + ValueNode expectedAlias = tool.getAlias(this.expected); + ValueNode newValueAlias = tool.getAlias(this.newValue); + + LogicNode equalsNode = CompareNode.createCompareNode(EQ, expectedAlias, currentValue, tool.getConstantReflectionProvider(), NodeView.DEFAULT); + if (equalsNode instanceof LogicConstantNode) { + + boolean equals = ((LogicConstantNode) equalsNode).getValue(); + if (equals) { + tool.setVirtualEntry(obj, index, newValueAlias); + } + tool.replaceWith(ConstantNode.forBoolean(equals)); + + } else if (!(currentValue instanceof VirtualInstanceNode) && + !(expectedAlias instanceof VirtualObjectNode) && + !(newValueAlias instanceof VirtualObjectNode)) { + ValueNode fieldValue = ConditionalNode.create(equalsNode, newValueAlias, currentValue, NodeView.DEFAULT); + ValueNode result = ConditionalNode.create(equalsNode, ConstantNode.forBoolean(true), ConstantNode.forBoolean(false), NodeView.DEFAULT); + + tool.setVirtualEntry(obj, index, fieldValue); + tool.addNode(equalsNode); + tool.addNode(fieldValue); + tool.addNode(result); + tool.replaceWith(result); + } + } + } + } + + private int resolveFieldIndex(VirtualInstanceNode obj) { + long fieldOffset = offset.asJavaConstant().asLong(); + ResolvedJavaField field = obj.type().findInstanceFieldWithOffset(fieldOffset, expected.getStackKind()); + return obj.fieldIndex(field); + } }