diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java index bfaf075a0ab..48e0d1567c2 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/PreparedStatement.java @@ -117,6 +117,20 @@ public interface PreparedStatement { @NonNull ColumnDefinitions getResultSetDefinitions(); + /** + * Informs if this is an LWT query. + * + *

Not guaranteed to return true for LWT queries (but guaranteed to return false for non-LWT + * ones). It can happen for several reasons, for example: using Cassandra instead of Scylla, using + * too old Scylla version, future changes in driver allowing channels to be created without + * sending OPTIONS request. + * + *

More information about LWT: + * + * @see Docs about LWT + */ + boolean isLWT(); + /** * Updates {@link #getResultMetadataId()} and {@link #getResultSetDefinitions()} atomically. * diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java index dcfd1420b53..78bd65aab3a 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java @@ -13,6 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +/* + * Copyright (C) 2022 ScyllaDB + * + * Modified by ScyllaDB + */ package com.datastax.oss.driver.api.core.cql; import com.datastax.oss.driver.api.core.ConsistencyLevel; @@ -516,6 +522,20 @@ default SelfT setNowInSeconds(int nowInSeconds) { return (SelfT) this; } + /** + * Informs if this is a prepared LWT query. + * + *

Not guaranteed to return true for prepared LWT queries (but guaranteed to return false for + * non-LWT ones). It can happen for several reasons, for example: using Cassandra instead of + * Scylla, using too old Scylla version, future changes in driver allowing channels to be created + * without sending OPTIONS request. + * + *

More information about LWT: + * + * @see Docs about LWT + */ + boolean isLWT(); + /** * Calculates the approximate size in bytes that the statement will have when encoded. * diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/DriverChannel.java b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/DriverChannel.java index 0284124a6f2..e1b1005cc14 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/DriverChannel.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/DriverChannel.java @@ -29,6 +29,7 @@ import com.datastax.oss.driver.internal.core.adminrequest.AdminRequestHandler; import com.datastax.oss.driver.internal.core.adminrequest.ThrottledAdminRequestHandler; import com.datastax.oss.driver.internal.core.pool.ChannelPool; +import com.datastax.oss.driver.internal.core.protocol.LwtInfo; import com.datastax.oss.driver.internal.core.protocol.ShardingInfo; import com.datastax.oss.driver.internal.core.protocol.ShardingInfo.ConnectionShardingInfo; import com.datastax.oss.driver.internal.core.session.DefaultSession; @@ -60,6 +61,7 @@ public class DriverChannel { AttributeKey.newInstance("options"); static final AttributeKey SHARDING_INFO_KEY = AttributeKey.newInstance("sharding_info"); + static final AttributeKey LWT_INFO_KEY = AttributeKey.newInstance("lwt_info"); @SuppressWarnings("RedundantStringConstructorCall") static final Object GRACEFUL_CLOSE_MESSAGE = new String("GRACEFUL_CLOSE_MESSAGE"); @@ -154,6 +156,10 @@ public ShardingInfo getShardingInfo() { : null; } + public LwtInfo getLwtInfo() { + return channel.attr(LWT_INFO_KEY).get(); + } + /** * @return the number of available stream ids on the channel; more precisely, this is the number * of {@link #preAcquireId()} calls for which the id has not been released yet. This is used diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ProtocolInitHandler.java b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ProtocolInitHandler.java index d95aef43e8d..9570e960536 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ProtocolInitHandler.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ProtocolInitHandler.java @@ -21,6 +21,8 @@ */ package com.datastax.oss.driver.internal.core.channel; +import static com.datastax.oss.driver.internal.core.channel.DriverChannel.LWT_INFO_KEY; + import com.datastax.oss.driver.api.core.DefaultProtocolVersion; import com.datastax.oss.driver.api.core.InvalidKeyspaceException; import com.datastax.oss.driver.api.core.ProtocolVersion; @@ -36,6 +38,7 @@ import com.datastax.oss.driver.internal.core.context.InternalDriverContext; import com.datastax.oss.driver.internal.core.protocol.BytesToSegmentDecoder; import com.datastax.oss.driver.internal.core.protocol.FrameToSegmentEncoder; +import com.datastax.oss.driver.internal.core.protocol.LwtInfo; import com.datastax.oss.driver.internal.core.protocol.SegmentToBytesEncoder; import com.datastax.oss.driver.internal.core.protocol.SegmentToFrameDecoder; import com.datastax.oss.driver.internal.core.protocol.ShardingInfo; @@ -61,7 +64,9 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import java.nio.ByteBuffer; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import net.jcip.annotations.NotThreadSafe; import org.slf4j.Logger; @@ -88,6 +93,7 @@ class ProtocolInitHandler extends ConnectInitHandler { private String logPrefix; private ChannelHandlerContext ctx; private final boolean querySupportedOptions; + private LwtInfo lwtInfo; /** * @param querySupportedOptions whether to send OPTIONS as the first message, to request which @@ -181,7 +187,11 @@ Message getRequest() { case OPTIONS: return request = Options.INSTANCE; case STARTUP: - return request = new Startup(context.getStartupOptions()); + Map startupOptions = new HashMap<>(context.getStartupOptions()); + if (lwtInfo != null) { + lwtInfo.addOption(startupOptions); + } + return request = new Startup(startupOptions); case GET_CLUSTER_NAME: return request = CLUSTER_NAME_QUERY; case SET_KEYSPACE: @@ -212,9 +222,13 @@ void onResponse(Message response) { if (step == Step.OPTIONS && response instanceof Supported) { channel.attr(DriverChannel.OPTIONS_KEY).set(((Supported) response).options); Supported res = (Supported) response; - ConnectionShardingInfo info = ShardingInfo.parseShardingInfo(res.options); - if (info != null) { - channel.attr(DriverChannel.SHARDING_INFO_KEY).set(info); + ConnectionShardingInfo shardingInfo = ShardingInfo.parseShardingInfo(res.options); + if (shardingInfo != null) { + channel.attr(DriverChannel.SHARDING_INFO_KEY).set(shardingInfo); + } + lwtInfo = LwtInfo.parseLwtInfo(res.options); + if (lwtInfo != null) { + channel.attr(LWT_INFO_KEY).set(lwtInfo); } step = Step.STARTUP; send(); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java index 3aaba7b7aec..763c1a8ffb5 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/Conversions.java @@ -71,6 +71,7 @@ import com.datastax.oss.driver.internal.core.context.InternalDriverContext; import com.datastax.oss.driver.internal.core.data.ValuesHelper; import com.datastax.oss.driver.internal.core.metadata.PartitionerFactory; +import com.datastax.oss.driver.internal.core.protocol.LwtInfo; import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; import com.datastax.oss.driver.shaded.guava.common.primitives.Ints; @@ -363,7 +364,7 @@ public static ColumnDefinitions getResultDefinitions( } public static DefaultPreparedStatement toPreparedStatement( - Prepared response, PrepareRequest request, InternalDriverContext context) { + Prepared response, PrepareRequest request, InternalDriverContext context, LwtInfo lwtInfo) { ColumnDefinitions variableDefinitions = toColumnDefinitions(response.variablesMetadata, context); @@ -402,7 +403,8 @@ public static DefaultPreparedStatement toPreparedStatement( request.getSerialConsistencyLevelForBoundStatements(), request.areBoundStatementsTracing(), context.getCodecRegistry(), - context.getProtocolVersion()); + context.getProtocolVersion(), + lwtInfo != null && lwtInfo.isLwt(response.variablesMetadata.flags)); } public static ColumnDefinitions toColumnDefinitions( diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareHandler.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareHandler.java index d60a6c65260..6448d329a6c 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareHandler.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareHandler.java @@ -232,13 +232,13 @@ private void recordError(Node node, Throwable error) { errorsSnapshot.add(new AbstractMap.SimpleEntry<>(node, error)); } - private void setFinalResult(PrepareRequest request, Prepared response) { + private void setFinalResult(PrepareRequest request, Prepared response, DriverChannel channel) { // Whatever happens below, we're done with this stream id throttler.signalSuccess(this); DefaultPreparedStatement preparedStatement = - Conversions.toPreparedStatement(response, request, context); + Conversions.toPreparedStatement(response, request, context, channel.getLwtInfo()); session .getRepreparePayloads() @@ -375,7 +375,7 @@ public void onResponse(Frame responseFrame) { Message responseMessage = responseFrame.message; if (responseMessage instanceof Prepared) { LOG.trace("[{}] Got result, completing", logPrefix); - setFinalResult(request, (Prepared) responseMessage); + setFinalResult(request, (Prepared) responseMessage, channel); } else if (responseMessage instanceof Error) { LOG.trace("[{}] Got error response, processing", logPrefix); processErrorResponse((Error) responseMessage); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java index 9af95ed37c7..a27324d1f28 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlRequestHandler.java @@ -90,9 +90,11 @@ import java.util.List; import java.util.Map; import java.util.Queue; +import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -195,15 +197,47 @@ public void onThrottleReady(boolean wasDelayed) { System.nanoTime() - startTimeNanos, TimeUnit.NANOSECONDS); } - Queue queryPlan = - this.initialStatement.getNode() != null - ? new SimpleQueryPlan(this.initialStatement.getNode()) - : context - .getLoadBalancingPolicyWrapper() - .newQueryPlan(initialStatement, executionProfile.getName(), session); + Queue queryPlan; + if (this.initialStatement.getNode() != null) { + queryPlan = new SimpleQueryPlan(this.initialStatement.getNode()); + } else if (this.initialStatement.isLWT()) { + queryPlan = + getReplicas( + session.getKeyspace().orElse(null), + this.initialStatement, + context + .getLoadBalancingPolicyWrapper() + .newQueryPlan(initialStatement, executionProfile.getName(), session)); + } else { + queryPlan = + context + .getLoadBalancingPolicyWrapper() + .newQueryPlan(initialStatement, executionProfile.getName(), session); + } + sendRequest(initialStatement, null, queryPlan, 0, 0, true); } + private Queue getReplicas( + CqlIdentifier loggedKeyspace, Statement statement, Queue fallback) { + Token routingToken = getRoutingToken(statement); + CqlIdentifier keyspace = statement.getKeyspace(); + if (keyspace == null) { + keyspace = statement.getRoutingKeyspace(); + if (keyspace == null) { + keyspace = loggedKeyspace; + } + } + + TokenMap tokenMap = context.getMetadataManager().getMetadata().getTokenMap().orElse(null); + if (routingToken == null || keyspace == null || tokenMap == null) { + return fallback; + } + + Set replicas = tokenMap.getReplicas(keyspace, routingToken); + return new ConcurrentLinkedQueue<>(replicas); + } + public CompletionStage handle() { return result; } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index 94d704c51ad..0068d3c8b98 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -13,6 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +/* + * Copyright (C) 2022 ScyllaDB + * + * Modified by ScyllaDB + */ package com.datastax.oss.driver.internal.core.cql; import com.datastax.oss.driver.api.core.ConsistencyLevel; @@ -783,4 +789,9 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { node, newNowInSeconds); } + + @Override + public boolean isLWT() { + return false; + } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java index 0520ee050c6..c815be00263 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBoundStatement.java @@ -770,4 +770,9 @@ public BoundStatement setNowInSeconds(int newNowInSeconds) { node, newNowInSeconds); } + + @Override + public boolean isLWT() { + return this.getPreparedStatement().isLWT(); + } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java index d4a296a4814..bae23bfce22 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultPreparedStatement.java @@ -66,6 +66,7 @@ public class DefaultPreparedStatement implements PreparedStatement { private final ConsistencyLevel serialConsistencyLevelForBoundStatements; private final Duration timeoutForBoundStatements; private final Partitioner partitioner; + private final boolean isLWT; public DefaultPreparedStatement( ByteBuffer id, @@ -91,7 +92,8 @@ public DefaultPreparedStatement( ConsistencyLevel serialConsistencyLevelForBoundStatements, boolean areBoundStatementsTracing, CodecRegistry codecRegistry, - ProtocolVersion protocolVersion) { + ProtocolVersion protocolVersion, + boolean isLWT) { this.id = id; this.partitionKeyIndices = partitionKeyIndices; // It's important that we keep a reference to this object, so that it only gets evicted from @@ -117,6 +119,7 @@ public DefaultPreparedStatement( this.codecRegistry = codecRegistry; this.protocolVersion = protocolVersion; + this.isLWT = isLWT; } @NonNull @@ -159,6 +162,11 @@ public ColumnDefinitions getResultSetDefinitions() { return resultMetadata.resultSetDefinitions; } + @Override + public boolean isLWT() { + return isLWT; + } + @Override public void setResultMetadata( @NonNull ByteBuffer newResultMetadataId, @NonNull ColumnDefinitions newResultSetDefinitions) { diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java index 4efc80a7dcc..00ae64b7c02 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultSimpleStatement.java @@ -13,6 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +/* + * Copyright (C) 2022 ScyllaDB + * + * Modified by ScyllaDB + */ package com.datastax.oss.driver.internal.core.cql; import com.datastax.oss.driver.api.core.ConsistencyLevel; @@ -741,6 +747,11 @@ public SimpleStatement setNowInSeconds(int newNowInSeconds) { newNowInSeconds); } + @Override + public boolean isLWT() { + return false; + } + public static Map wrapKeys(Map namedValues) { NullAllowingImmutableMap.Builder builder = NullAllowingImmutableMap.builder(); diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/protocol/LwtInfo.java b/core/src/main/java/com/datastax/oss/driver/internal/core/protocol/LwtInfo.java new file mode 100644 index 00000000000..5ac8abc2e53 --- /dev/null +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/protocol/LwtInfo.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2022 ScyllaDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.datastax.oss.driver.internal.core.protocol; + +import java.util.List; +import java.util.Map; + +public class LwtInfo { + private static final String SCYLLA_LWT_ADD_METADATA_MARK_KEY = "SCYLLA_LWT_ADD_METADATA_MARK"; + private static final String LWT_OPTIMIZATION_META_BIT_MASK_KEY = "LWT_OPTIMIZATION_META_BIT_MASK"; + + private final int mask; + + private LwtInfo(int m) { + mask = m; + } + + public int getMask() { + return mask; + } + + public boolean isLwt(int flags) { + return (flags & mask) == mask; + } + + public static LwtInfo parseLwtInfo(Map> supported) { + if (!supported.containsKey(SCYLLA_LWT_ADD_METADATA_MARK_KEY)) { + return null; + } + List list = supported.get(SCYLLA_LWT_ADD_METADATA_MARK_KEY); + if (list == null || list.size() != 1) { + return null; + } + String val = list.get(0); + if (val == null || !val.startsWith(LWT_OPTIMIZATION_META_BIT_MASK_KEY + "=")) { + return null; + } + long mask; + try { + mask = Long.parseLong(val.substring((LWT_OPTIMIZATION_META_BIT_MASK_KEY + "=").length())); + } catch (Exception e) { + System.err.println( + "Error while parsing " + LWT_OPTIMIZATION_META_BIT_MASK_KEY + ": " + e.getMessage()); + return null; + } + if (mask > Integer.MAX_VALUE) { + // Unfortunately server returns mask as unsigned int32 so we have to parse it as int64 and + // convert to proper signed int32 + mask += Integer.MIN_VALUE; + mask += Integer.MIN_VALUE; + } + + return new LwtInfo((int) mask); + } + + public void addOption(Map options) { + options.put(SCYLLA_LWT_ADD_METADATA_MARK_KEY, Integer.toString(mask)); + } +} diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java index f1b3a21beb8..b9bb43a1bd4 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/tracker/RequestLogFormatterTest.java @@ -291,6 +291,7 @@ private PreparedStatement mockPreparedStatement(String query, Map SESSION_RULE = + SessionRule.builder(CCM_RULE).withKeyspace(false).build(); + + @ClassRule + public static final TestRule CHAIN = RuleChain.outerRule(CCM_RULE).around(SESSION_RULE); + + @BeforeClass + public static void setup() { + CqlIdentifier keyspace = SessionUtils.uniqueKeyspaceId(); + CqlSession session = SESSION_RULE.session(); + session.execute( + "CREATE KEYSPACE " + + keyspace.asCql(false) + + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}"); + session.execute("USE " + keyspace.asCql(false)); + session.execute("CREATE TABLE foo (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); + } + + @Test + public void should_use_only_one_node_when_lwt_detected() { + assumeTrue(CcmBridge.SCYLLA_ENABLEMENT); // Functionality only available in Scylla + CqlSession session = SESSION_RULE.session(); + int pk = 1234; + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Node owner = tokenMap.getReplicas(session.getKeyspace().get(), routingKey).iterator().next(); + PreparedStatement statement = + SESSION_RULE + .session() + .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + assertThat(statement.isLWT()).isTrue(); + for (int i = 0; i < 30; i++) { + ResultSet result = session.execute(statement.bind(pk, i, 123)); + assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + } + } + + @Test + // Sanity check for the previous test - non-LWT queries should + // not always be sent to same node + public void should_not_use_only_one_node_when_non_lwt() { + CqlSession session = SESSION_RULE.session(); + int pk = 1234; + PreparedStatement statement = session.prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?)"); + assertThat(statement.isLWT()).isFalse(); + Set coordinators = new HashSet<>(); + for (int i = 0; i < 30; i++) { + ResultSet result = session.execute(statement.bind(pk, i, 123)); + coordinators.add(result.getExecutionInfo().getCoordinator()); + } + + // Because keyspace RF == 3 + assertThat(coordinators.size()).isEqualTo(3); + } +} diff --git a/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/CassandraSkip.java b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/CassandraSkip.java new file mode 100644 index 00000000000..521b257cc0e --- /dev/null +++ b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/CassandraSkip.java @@ -0,0 +1,29 @@ +/* + * Copyright (C) 2022 ScyllaDB + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.driver.api.testinfra; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * Annotation for a Class or Method that skips it for Cassandra. If the tests are run against + * Cassandra, the test is skipped. + */ +@Retention(RetentionPolicy.RUNTIME) +public @interface CassandraSkip { + /** @return The description returned if this requirement is not met. */ + String description() default "Disabled for Cassandra."; +} diff --git a/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/BaseCcmRule.java b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/BaseCcmRule.java index 91b959c6e46..55bc67d84c5 100644 --- a/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/BaseCcmRule.java +++ b/test-infra/src/main/java/com/datastax/oss/driver/api/testinfra/ccm/BaseCcmRule.java @@ -24,10 +24,7 @@ import com.datastax.oss.driver.api.core.DefaultProtocolVersion; import com.datastax.oss.driver.api.core.ProtocolVersion; import com.datastax.oss.driver.api.core.Version; -import com.datastax.oss.driver.api.testinfra.CassandraRequirement; -import com.datastax.oss.driver.api.testinfra.CassandraResourceRule; -import com.datastax.oss.driver.api.testinfra.DseRequirement; -import com.datastax.oss.driver.api.testinfra.ScyllaSkip; +import com.datastax.oss.driver.api.testinfra.*; import java.util.Optional; import org.junit.AssumptionViolatedException; import org.junit.runner.Description; @@ -98,6 +95,21 @@ public void evaluate() { } } + CassandraSkip cassandraSkip = description.getAnnotation(CassandraSkip.class); + if (cassandraSkip != null) { + if (!CcmBridge.SCYLLA_ENABLEMENT) { + return new Statement() { + + @Override + public void evaluate() { + throw new AssumptionViolatedException( + String.format( + "Test skipped when running with Cassandra. Description: %s", description)); + } + }; + } + } + // If test is annotated with CassandraRequirement or DseRequirement, ensure configured CCM // cluster meets those requirements. CassandraRequirement cassandraRequirement =