diff --git a/build.gradle b/build.gradle index 2fec351718b7..3df7068bf334 100644 --- a/build.gradle +++ b/build.gradle @@ -26,6 +26,7 @@ configure(allprojects) { project -> mavenBom "com.fasterxml.jackson:jackson-bom:2.11.0" mavenBom "io.netty:netty-bom:4.1.50.Final" mavenBom "io.projectreactor:reactor-bom:2020.0.0-SNAPSHOT" + mavenBom "io.r2dbc:r2dbc-bom:Arabba-SR5" mavenBom "io.rsocket:rsocket-bom:1.0.1" mavenBom "org.eclipse.jetty:jetty-bom:9.4.29.v20200521" mavenBom "org.jetbrains.kotlin:kotlin-bom:1.3.72" diff --git a/integration-tests/integration-tests.gradle b/integration-tests/integration-tests.gradle index 71e23b906049..57cc7ac2bf7d 100644 --- a/integration-tests/integration-tests.gradle +++ b/integration-tests/integration-tests.gradle @@ -11,6 +11,7 @@ dependencies { testCompile(testFixtures(project(":spring-tx"))) testCompile(project(":spring-expression")) testCompile(project(":spring-jdbc")) + testCompile(project(":spring-r2dbc")) testCompile(project(":spring-orm")) testCompile(project(":spring-test")) testCompile(project(":spring-tx")) diff --git a/settings.gradle b/settings.gradle index 1d84e98f4230..5ceee6411652 100644 --- a/settings.gradle +++ b/settings.gradle @@ -23,6 +23,7 @@ include "spring-expression" include "spring-instrument" include "spring-jcl" include "spring-jdbc" +include "spring-r2dbc" include "spring-jms" include "spring-messaging" include "spring-orm" diff --git a/spring-r2dbc/spring-r2dbc.gradle b/spring-r2dbc/spring-r2dbc.gradle new file mode 100644 index 000000000000..cfac71726e5a --- /dev/null +++ b/spring-r2dbc/spring-r2dbc.gradle @@ -0,0 +1,23 @@ +description = "Spring R2DBC" + +apply plugin: "kotlin" + +dependencies { + compile(project(":spring-beans")) + compile(project(":spring-core")) + compile(project(":spring-tx")) + compile("io.r2dbc:r2dbc-spi") + compile("io.projectreactor:reactor-core") + compileOnly(project(":kotlin-coroutines")) + optional("org.jetbrains.kotlin:kotlin-reflect") + optional("org.jetbrains.kotlin:kotlin-stdlib") + optional("org.jetbrains.kotlinx:kotlinx-coroutines-core") + optional("org.jetbrains.kotlinx:kotlinx-coroutines-reactor") + testCompile(project(":kotlin-coroutines")) + testCompile(testFixtures(project(":spring-beans"))) + testCompile(testFixtures(project(":spring-core"))) + testCompile(testFixtures(project(":spring-context"))) + testCompile("io.projectreactor:reactor-test") + testCompile("io.r2dbc:r2dbc-h2") + testCompile("io.r2dbc:r2dbc-spi-test:0.8.1.RELEASE") +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/BadSqlGrammarException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/BadSqlGrammarException.java new file mode 100644 index 000000000000..126a7853168c --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/BadSqlGrammarException.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc; + +import io.r2dbc.spi.R2dbcException; + +import org.springframework.dao.InvalidDataAccessResourceUsageException; + + +/** + * Exception thrown when SQL specified is invalid. Such exceptions always have a + * {@link io.r2dbc.spi.R2dbcException} root cause. + * + *

It would be possible to have subclasses for no such table, no such column etc. + * A custom R2dbcExceptionTranslator could create such more specific exceptions, + * without affecting code using this class. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class BadSqlGrammarException extends InvalidDataAccessResourceUsageException { + + private final String sql; + + + /** + * Constructor for BadSqlGrammarException. + * @param task name of current task + * @param sql the offending SQL statement + * @param ex the root cause + */ + public BadSqlGrammarException(String task, String sql, R2dbcException ex) { + super(task + "; bad SQL grammar [" + sql + "]", ex); + this.sql = sql; + } + + + /** + * Return the wrapped {@link R2dbcException}. + */ + public R2dbcException getR2dbcException() { + return (R2dbcException) getCause(); + } + + /** + * Return the SQL that caused the problem. + */ + public String getSql() { + return this.sql; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/UncategorizedR2dbcException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/UncategorizedR2dbcException.java new file mode 100644 index 000000000000..2944f01f3a46 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/UncategorizedR2dbcException.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc; + +import io.r2dbc.spi.R2dbcException; + +import org.springframework.dao.UncategorizedDataAccessException; +import org.springframework.lang.Nullable; + +/** + * Exception thrown when we can't classify a {@link R2dbcException} into + * one of our generic data access exceptions. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class UncategorizedR2dbcException extends UncategorizedDataAccessException { + + /** SQL that led to the problem. */ + @Nullable + private final String sql; + + + /** + * Constructor for {@code UncategorizedSQLException}. + * @param msg the detail message + * @param sql the offending SQL statement + * @param ex the exception thrown by underlying data access API + */ + public UncategorizedR2dbcException(String msg, @Nullable String sql, R2dbcException ex) { + super(msg, ex); + this.sql = sql; + } + + + /** + * Return the wrapped {@link R2dbcException}. + */ + public R2dbcException getR2dbcException() { + return (R2dbcException) getCause(); + } + + /** + * Return the SQL that led to the problem (if known). + */ + @Nullable + public String getSql() { + return this.sql; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/ConnectionFactoryUtils.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/ConnectionFactoryUtils.java new file mode 100644 index 000000000000..ef1a0e582e87 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/ConnectionFactoryUtils.java @@ -0,0 +1,442 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.R2dbcBadGrammarException; +import io.r2dbc.spi.R2dbcDataIntegrityViolationException; +import io.r2dbc.spi.R2dbcException; +import io.r2dbc.spi.R2dbcNonTransientException; +import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.r2dbc.spi.R2dbcPermissionDeniedException; +import io.r2dbc.spi.R2dbcRollbackException; +import io.r2dbc.spi.R2dbcTimeoutException; +import io.r2dbc.spi.R2dbcTransientException; +import io.r2dbc.spi.R2dbcTransientResourceException; +import io.r2dbc.spi.Wrapped; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import reactor.core.publisher.Mono; + +import org.springframework.core.Ordered; +import org.springframework.dao.ConcurrencyFailureException; +import org.springframework.dao.DataAccessException; +import org.springframework.dao.DataAccessResourceFailureException; +import org.springframework.dao.DataIntegrityViolationException; +import org.springframework.dao.PermissionDeniedDataAccessException; +import org.springframework.dao.QueryTimeoutException; +import org.springframework.dao.TransientDataAccessResourceException; +import org.springframework.lang.Nullable; +import org.springframework.r2dbc.BadSqlGrammarException; +import org.springframework.r2dbc.UncategorizedR2dbcException; +import org.springframework.transaction.NoTransactionException; +import org.springframework.transaction.reactive.TransactionSynchronization; +import org.springframework.transaction.reactive.TransactionSynchronizationManager; +import org.springframework.util.Assert; + +/** + * Helper class that provides static methods for obtaining R2DBC Connections from + * a {@link ConnectionFactory}. + * + *

Used internally by Spring's {@code DatabaseClient}, Spring's R2DBC operation + * objects. Can also be used directly in application code. + * + * @author Mark Paluch + * @author Christoph Strobl + * @since 5.3 + * @see R2dbcTransactionManager + * @see org.springframework.transaction.reactive.TransactionSynchronizationManager + */ +public abstract class ConnectionFactoryUtils { + + /** + * Order value for ReactiveTransactionSynchronization objects that clean up R2DBC Connections. + */ + public static final int CONNECTION_SYNCHRONIZATION_ORDER = 1000; + + private static final Log logger = LogFactory.getLog(ConnectionFactoryUtils.class); + + + private ConnectionFactoryUtils() {} + + + /** + * Obtain a {@link Connection} from the given {@link ConnectionFactory}. + * Translates exceptions into the Spring hierarchy of unchecked generic + * data access exceptions, simplifying calling code and making any + * exception that is thrown more meaningful. + * + *

Is aware of a corresponding Connection bound to the current + * {@link TransactionSynchronizationManager}. Will bind a Connection to the + * {@link TransactionSynchronizationManager} if transaction synchronization is active. + * @param connectionFactory the {@link ConnectionFactory} to obtain + * {@link Connection Connections} from + * @return a R2DBC Connection from the given {@link ConnectionFactory} + * @throws DataAccessResourceFailureException if the attempt to get a + * {@link Connection} failed + * @see #releaseConnection + */ + public static Mono getConnection(ConnectionFactory connectionFactory) { + return doGetConnection(connectionFactory) + .onErrorMap(e -> new DataAccessResourceFailureException("Failed to obtain R2DBC Connection", e)); + } + + /** + * Actually obtain a R2DBC Connection from the given {@link ConnectionFactory}. + * Same as {@link #getConnection}, but preserving the original exceptions. + * + *

Is aware of a corresponding Connection bound to the current + * {@link TransactionSynchronizationManager}. Will bind a Connection to the + * {@link TransactionSynchronizationManager} if transaction synchronization is active + * @param connectionFactory the {@link ConnectionFactory} to obtain Connections from + * @return a R2DBC {@link Connection} from the given {@link ConnectionFactory}. + */ + public static Mono doGetConnection(ConnectionFactory connectionFactory) { + Assert.notNull(connectionFactory, "ConnectionFactory must not be null"); + return TransactionSynchronizationManager.forCurrentTransaction().flatMap(synchronizationManager -> { + + ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(connectionFactory); + if (conHolder != null && (conHolder.hasConnection() || conHolder.isSynchronizedWithTransaction())) { + conHolder.requested(); + if (!conHolder.hasConnection()) { + + if (logger.isDebugEnabled()) { + logger.debug("Fetching resumed R2DBC Connection from ConnectionFactory"); + } + return fetchConnection(connectionFactory).doOnNext(conHolder::setConnection); + } + return Mono.just(conHolder.getConnection()); + } + // Else we either got no holder or an empty thread-bound holder here. + + if (logger.isDebugEnabled()) { + logger.debug("Fetching R2DBC Connection from ConnectionFactory"); + } + + Mono con = fetchConnection(connectionFactory); + + if (synchronizationManager.isSynchronizationActive()) { + + return con.flatMap(connection -> { + return Mono.just(connection).doOnNext(conn -> { + + // Use same Connection for further R2DBC actions within the transaction. + // Thread-bound object will get removed by synchronization at transaction completion. + ConnectionHolder holderToUse = conHolder; + if (holderToUse == null) { + holderToUse = new ConnectionHolder(conn); + } + else { + holderToUse.setConnection(conn); + } + holderToUse.requested(); + synchronizationManager + .registerSynchronization(new ConnectionSynchronization(holderToUse, connectionFactory)); + holderToUse.setSynchronizedWithTransaction(true); + if (holderToUse != conHolder) { + synchronizationManager.bindResource(connectionFactory, holderToUse); + } + }) // Unexpected exception from external delegation call -> close Connection and rethrow. + .onErrorResume(e -> releaseConnection(connection, connectionFactory).then(Mono.error(e))); + }); + } + + return con; + }).onErrorResume(NoTransactionException.class, e -> Mono.from(connectionFactory.create())); + } + + /** + * Actually fetch a {@link Connection} from the given {@link ConnectionFactory}. + * @param connectionFactory the {@link ConnectionFactory} to obtain + * {@link Connection}s from + * @return a R2DBC {@link Connection} from the given {@link ConnectionFactory} + * (never {@code null}). + * @throws IllegalStateException if the {@link ConnectionFactory} returned a {@code null} value. + * @see ConnectionFactory#create() + */ + private static Mono fetchConnection(ConnectionFactory connectionFactory) { + return Mono.from(connectionFactory.create()); + } + + /** + * Close the given {@link Connection}, obtained from the given {@link ConnectionFactory}, if + * it is not managed externally (that is, not bound to the subscription). + * @param con the {@link Connection} to close if necessary + * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from + * @see #getConnection + */ + public static Mono releaseConnection(Connection con, ConnectionFactory connectionFactory) { + return doReleaseConnection(con, connectionFactory) + .onErrorMap(e -> new DataAccessResourceFailureException("Failed to close R2DBC Connection", e)); + } + + /** + * Actually close the given {@link Connection}, obtained from the given + * {@link ConnectionFactory}. Same as {@link #releaseConnection}, + * but preserving the original exception. + * @param connection the {@link Connection} to close if necessary + * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from + * @see #doGetConnection + */ + public static Mono doReleaseConnection(Connection connection, + ConnectionFactory connectionFactory) { + return TransactionSynchronizationManager.forCurrentTransaction() + .flatMap(synchronizationManager -> { + + ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(connectionFactory); + if (conHolder != null && connectionEquals(conHolder, connection)) { + // It's the transactional Connection: Don't close it. + conHolder.released(); + } + return Mono.from(connection.close()); + }).onErrorResume(NoTransactionException.class, e -> Mono.from(connection.close())); + } + + /** + * Obtain the {@link ConnectionFactory} from the current {@link TransactionSynchronizationManager}. + * @param connectionFactory the {@link ConnectionFactory} that the Connection was obtained from + * @see TransactionSynchronizationManager + */ + public static Mono currentConnectionFactory(ConnectionFactory connectionFactory) { + return TransactionSynchronizationManager.forCurrentTransaction() + .filter(TransactionSynchronizationManager::isSynchronizationActive) + .filter(synchronizationManager -> { + ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(connectionFactory); + return conHolder != null && (conHolder.hasConnection() || conHolder.isSynchronizedWithTransaction()); + }).map(synchronizationManager -> connectionFactory); + } + + /** + * Translate the given {@link R2dbcException} into a generic {@link DataAccessException}. + *

The returned DataAccessException is supposed to contain the original + * {@link R2dbcException} as root cause. However, client code may not generally + * rely on this due to DataAccessExceptions possibly being caused by other resource + * APIs as well. That said, a {@code getRootCause() instanceof R2dbcException} + * check (and subsequent cast) is considered reliable when expecting R2DBC-based + * access to have happened. + * @param task readable text describing the task being attempted + * @param sql the SQL query or update that caused the problem (if known) + * @param ex the offending {@link R2dbcException} + * @return the corresponding DataAccessException instance + */ + public static DataAccessException convertR2dbcException(String task, @Nullable String sql, R2dbcException ex) { + + if (ex instanceof R2dbcTransientException) { + if (ex instanceof R2dbcTransientResourceException) { + return new TransientDataAccessResourceException(buildMessage(task, sql, ex), ex); + } + if (ex instanceof R2dbcRollbackException) { + return new ConcurrencyFailureException(buildMessage(task, sql, ex), ex); + } + if (ex instanceof R2dbcTimeoutException) { + return new QueryTimeoutException(buildMessage(task, sql, ex), ex); + } + } + + if (ex instanceof R2dbcNonTransientException) { + if (ex instanceof R2dbcNonTransientResourceException) { + return new DataAccessResourceFailureException(buildMessage(task, sql, ex), ex); + } + if (ex instanceof R2dbcDataIntegrityViolationException) { + return new DataIntegrityViolationException(buildMessage(task, sql, ex), ex); + } + if (ex instanceof R2dbcPermissionDeniedException) { + return new PermissionDeniedDataAccessException(buildMessage(task, sql, ex), ex); + } + if (ex instanceof R2dbcBadGrammarException) { + return new BadSqlGrammarException(task, (sql != null ? sql : ""), ex); + } + } + + return new UncategorizedR2dbcException(buildMessage(task, sql, ex), sql, ex); + } + + /** + * Build a message {@code String} for the given {@link R2dbcException}. + *

To be called by translator subclasses when creating an instance of a generic + * {@link org.springframework.dao.DataAccessException} class. + * @param task readable text describing the task being attempted + * @param sql the SQL statement that caused the problem + * @param ex the offending {@code R2dbcException} + * @return the message {@code String} to use + */ + private static String buildMessage(String task, @Nullable String sql, R2dbcException ex) { + return task + "; " + (sql != null ? ("SQL [" + sql + "]; ") : "") + ex.getMessage(); + } + + /** + * Determine whether the given two {@link Connection}s are equal, asking the target + * {@link Connection} in case of a proxy. Used to detect equality even if the user + * passed in a raw target Connection while the held one is a proxy. + * @param conHolder the {@link ConnectionHolder} for the held {@link Connection} (potentially a proxy) + * @param passedInCon the {@link Connection} passed-in by the user (potentially + * a target {@link Connection} without proxy). + * @return whether the given Connections are equal + * @see #getTargetConnection + */ + private static boolean connectionEquals(ConnectionHolder conHolder, Connection passedInCon) { + if (!conHolder.hasConnection()) { + return false; + } + Connection heldCon = conHolder.getConnection(); + // Explicitly check for identity too: for Connection handles that do not implement + // "equals" properly). + return (heldCon == passedInCon || heldCon.equals(passedInCon) || getTargetConnection(heldCon).equals(passedInCon)); + } + + /** + * Return the innermost target {@link Connection} of the given {@link Connection}. + * If the given {@link Connection} is wrapped, it will be unwrapped until a + * plain {@link Connection} is found. Otherwise, the passed-in Connection + * will be returned as-is. + * @param con the {@link Connection} wrapper to unwrap + * @return the innermost target Connection, or the passed-in one if not wrapped + * @see Wrapped#unwrap() + */ + @SuppressWarnings("unchecked") + public static Connection getTargetConnection(Connection con) { + Connection conToUse = con; + while (conToUse instanceof Wrapped) { + conToUse = ((Wrapped) conToUse).unwrap(); + } + return conToUse; + } + + /** + * Determine the connection synchronization order to use for the given {@link ConnectionFactory}. + * Decreased for every level of nesting that a {@link ConnectionFactory} has, + * checked through the level of {@link DelegatingConnectionFactory} nesting. + * @param connectionFactory the {@link ConnectionFactory} to check + * @return the connection synchronization order to use + * @see #CONNECTION_SYNCHRONIZATION_ORDER + */ + private static int getConnectionSynchronizationOrder(ConnectionFactory connectionFactory) { + + int order = CONNECTION_SYNCHRONIZATION_ORDER; + ConnectionFactory current = connectionFactory; + while (current instanceof DelegatingConnectionFactory) { + order--; + current = ((DelegatingConnectionFactory) current).getTargetConnectionFactory(); + } + return order; + } + + /** + * Callback for resource cleanup at the end of a non-native R2DBC transaction. + */ + private static class ConnectionSynchronization implements TransactionSynchronization, Ordered { + + private final ConnectionHolder connectionHolder; + + private final ConnectionFactory connectionFactory; + + private final int order; + + private boolean holderActive = true; + + ConnectionSynchronization(ConnectionHolder connectionHolder, ConnectionFactory connectionFactory) { + this.connectionHolder = connectionHolder; + this.connectionFactory = connectionFactory; + this.order = getConnectionSynchronizationOrder(connectionFactory); + } + + + @Override + public int getOrder() { + return this.order; + } + + @Override + public Mono suspend() { + if (this.holderActive) { + return TransactionSynchronizationManager.forCurrentTransaction() + .flatMap(synchronizationManager -> { + + synchronizationManager.unbindResource(this.connectionFactory); + if (this.connectionHolder.hasConnection() && !this.connectionHolder.isOpen()) { + // Release Connection on suspend if the application doesn't keep + // a handle to it anymore. We will fetch a fresh Connection if the + // application accesses the ConnectionHolder again after resume, + // assuming that it will participate in the same transaction. + return releaseConnection(this.connectionHolder.getConnection(), this.connectionFactory) + .doOnTerminate(() -> this.connectionHolder.setConnection(null)); + } + return Mono.empty(); + }); + } + + return Mono.empty(); + } + + @Override + public Mono resume() { + if (this.holderActive) { + return TransactionSynchronizationManager.forCurrentTransaction() + .doOnNext(synchronizationManager -> synchronizationManager.bindResource(this.connectionFactory, this.connectionHolder)) + .then(); + } + return Mono.empty(); + } + + @Override + public Mono beforeCompletion() { + // Release Connection early if the holder is not open anymore + // (that is, not used by another resource + // that has its own cleanup via transaction synchronization), + // to avoid issues with strict transaction implementations that expect + // the close call before transaction completion. + if (!this.connectionHolder.isOpen()) { + return TransactionSynchronizationManager.forCurrentTransaction() + .flatMap(synchronizationManager -> { + synchronizationManager.unbindResource(this.connectionFactory); + this.holderActive = false; + if (this.connectionHolder.hasConnection()) { + return releaseConnection(this.connectionHolder.getConnection(), this.connectionFactory); + } + return Mono.empty(); + }); + } + + return Mono.empty(); + } + + @Override + public Mono afterCompletion(int status) { + // If we haven't closed the Connection in beforeCompletion, + // close it now. + if (this.holderActive) { + // The bound ConnectionHolder might not be available anymore, + // since afterCompletion might get called from a different thread. + return TransactionSynchronizationManager.forCurrentTransaction() + .flatMap(synchronizationManager -> { + synchronizationManager.unbindResourceIfPossible(this.connectionFactory); + this.holderActive = false; + if (this.connectionHolder.hasConnection()) { + return releaseConnection(this.connectionHolder.getConnection(), this.connectionFactory) + // Reset the ConnectionHolder: It might remain bound to the context. + .doOnTerminate(() -> this.connectionHolder.setConnection(null)); + } + return Mono.empty(); + }); + } + + this.connectionHolder.reset(); + return Mono.empty(); + } + } +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/ConnectionHolder.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/ConnectionHolder.java new file mode 100644 index 000000000000..c4efc83c2742 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/ConnectionHolder.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.lang.Nullable; +import org.springframework.transaction.support.ResourceHolderSupport; +import org.springframework.util.Assert; + + +/** + * Resource holder wrapping a R2DBC {@link Connection}. + * {@link R2dbcTransactionManager} binds instances of this class to the subscription, + * for a specific {@link ConnectionFactory}. + * + *

Inherits rollback-only support for nested R2DBC transactions and reference + * count functionality from the base class. + * + *

Note: This is an SPI class, not intended to be used by applications. + * + * @author Mark Paluch + * @author Christoph Strobl + * @since 5.3 + * @see R2dbcTransactionManager + * @see ConnectionFactoryUtils + */ +public class ConnectionHolder extends ResourceHolderSupport { + + @Nullable + private Connection currentConnection; + + private boolean transactionActive; + + + /** + * Create a new ConnectionHolder for the given R2DBC {@link Connection}, + * assuming that there is no ongoing transaction. + * @param connection the R2DBC {@link Connection} to hold + * @see #ConnectionHolder(Connection, boolean) + */ + public ConnectionHolder(Connection connection) { + this(connection, false); + } + + /** + * Create a new ConnectionHolder for the given R2DBC {@link Connection}. + * @param connection the R2DBC {@link Connection} to hold + * @param transactionActive whether the given {@link Connection} is involved + * in an ongoing transaction + */ + public ConnectionHolder(Connection connection, boolean transactionActive) { + + this.currentConnection = connection; + this.transactionActive = transactionActive; + } + + + /** + * Return whether this holder currently has a {@link Connection}. + */ + protected boolean hasConnection() { + return (this.currentConnection != null); + } + + /** + * Set whether this holder represents an active, R2DBC-managed transaction. + * + * @see R2dbcTransactionManager + */ + protected void setTransactionActive(boolean transactionActive) { + this.transactionActive = transactionActive; + } + + /** + * Return whether this holder represents an active, R2DBC-managed transaction. + */ + protected boolean isTransactionActive() { + return this.transactionActive; + } + + /** + * Override the existing Connection with the given {@link Connection}. + *

Used for releasing the {@link Connection} on suspend + * (with a {@code null} argument) and setting a fresh {@link Connection} on resume. + */ + protected void setConnection(@Nullable Connection connection) { + this.currentConnection = connection; + } + + /** + * Return the current {@link Connection} held by this {@link ConnectionHolder}. + *

This will be the same {@link Connection} until {@code released} gets called + * on the {@link ConnectionHolder}, which will reset the held {@link Connection}, + * fetching a new {@link Connection} on demand. + * @see #released() + */ + public Connection getConnection() { + + Assert.notNull(this.currentConnection, "Active Connection is required"); + return this.currentConnection; + } + + /** + * Releases the current {@link Connection}. + */ + @Override + public void released() { + super.released(); + if (!isOpen() && this.currentConnection != null) { + this.currentConnection = null; + } + } + + @Override + public void clear() { + super.clear(); + this.transactionActive = false; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/DelegatingConnectionFactory.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/DelegatingConnectionFactory.java new file mode 100644 index 000000000000..69615f5b4dfc --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/DelegatingConnectionFactory.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import io.r2dbc.spi.Wrapped; +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; + +/** + * R2DBC {@link ConnectionFactory} implementation that delegates + * all calls to a given target {@link ConnectionFactory}. + * + *

This class is meant to be subclassed, with subclasses overriding + * only those methods (such as {@link #create()}) that should not simply + * delegate to the target {@link ConnectionFactory}. + * + * @author Mark Paluch + * @since 5.3 + * @see #create + */ +public class DelegatingConnectionFactory implements ConnectionFactory, Wrapped { + + private final ConnectionFactory targetConnectionFactory; + + + public DelegatingConnectionFactory(ConnectionFactory targetConnectionFactory) { + Assert.notNull(targetConnectionFactory, "ConnectionFactory must not be null"); + this.targetConnectionFactory = targetConnectionFactory; + } + + + @Override + public Mono create() { + return Mono.from(this.targetConnectionFactory.create()); + } + + public ConnectionFactory getTargetConnectionFactory() { + return this.targetConnectionFactory; + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + return obtainTargetConnectionFactory().getMetadata(); + } + + @Override + public ConnectionFactory unwrap() { + return obtainTargetConnectionFactory(); + } + + /** + * Obtain the target {@link ConnectionFactory} for actual use (never {@code null}). + */ + protected ConnectionFactory obtainTargetConnectionFactory() { + return getTargetConnectionFactory(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java new file mode 100644 index 000000000000..c3f553021a82 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/R2dbcTransactionManager.java @@ -0,0 +1,538 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import java.time.Duration; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.R2dbcException; +import io.r2dbc.spi.Result; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.transaction.CannotCreateTransactionException; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.TransactionException; +import org.springframework.transaction.reactive.AbstractReactiveTransactionManager; +import org.springframework.transaction.reactive.GenericReactiveTransaction; +import org.springframework.transaction.reactive.TransactionSynchronizationManager; +import org.springframework.util.Assert; + +/** + * {@link org.springframework.transaction.ReactiveTransactionManager} + * implementation for a single R2DBC {@link ConnectionFactory}. This class is + * capable of working in any environment with any R2DBC driver, as long as the + * setup uses a {@link ConnectionFactory} as its {@link Connection} factory + * mechanism. Binds a R2DBC {@link Connection} from the specified + * {@link ConnectionFactory} to the current subscriber context, potentially + * allowing for one context-bound {@link Connection} per {@link ConnectionFactory}. + * + *

Note: The {@link ConnectionFactory} that this transaction manager + * operates on needs to return independent {@link Connection}s. + * The {@link Connection}s may come from a pool (the typical case), but the + * {@link ConnectionFactory} must not return scoped scoped {@link Connection}s + * or the like. This transaction manager will associate {@link Connection} + * with context-bound transactions itself, according to the specified propagation + * behavior. It assumes that a separate, independent {@link Connection} can + * be obtained even during an ongoing transaction. + * + *

Application code is required to retrieve the R2DBC Connection via + * {@link ConnectionFactoryUtils#getConnection(ConnectionFactory)} + * instead of a standard R2DBC-style {@link ConnectionFactory#create()} call. + * Spring classes such as {@code DatabaseClient} use this strategy implicitly. + * If not used in combination with this transaction manager, the + * {@link ConnectionFactoryUtils} lookup strategy behaves exactly like the + * native {@link ConnectionFactory} lookup; it can thus be used in a portable fashion. + * + *

Alternatively, you can allow application code to work with the standard + * R2DBC lookup pattern {@link ConnectionFactory#create()}, for example for code + * that is not aware of Spring at all. In that case, define a + * {@link TransactionAwareConnectionFactoryProxy} for your target {@link ConnectionFactory}, + * and pass that proxy {@link ConnectionFactory} to your DAOs, which will automatically + * participate in Spring-managed transactions when accessing it. + * + *

This transaction manager triggers flush callbacks on registered transaction + * synchronizations (if synchronization is generally active), assuming resources + * operating on the underlying R2DBC {@link Connection}. + * + * @author Mark Paluch + * @since 5.3 + * @see ConnectionFactoryUtils#getConnection(ConnectionFactory) + * @see ConnectionFactoryUtils#releaseConnection + * @see TransactionAwareConnectionFactoryProxy + */ +@SuppressWarnings("serial") +public class R2dbcTransactionManager extends AbstractReactiveTransactionManager implements InitializingBean { + + private ConnectionFactory connectionFactory; + + private boolean enforceReadOnly = false; + + + /** + * Create a new @link ConnectionFactoryTransactionManager} instance. A ConnectionFactory has to be set to be able to + * use it. + * + * @see #setConnectionFactory + */ + public R2dbcTransactionManager() {} + + /** + * Create a new {@link R2dbcTransactionManager} instance. + * + * @param connectionFactory the R2DBC ConnectionFactory to manage transactions for + */ + public R2dbcTransactionManager(ConnectionFactory connectionFactory) { + this(); + setConnectionFactory(connectionFactory); + afterPropertiesSet(); + } + + + /** + * Set the R2DBC {@link ConnectionFactory} that this instance should manage transactions for. + *

+ * This will typically be a locally defined {@link ConnectionFactory}, for example an connection pool. + *

+ * The {@link ConnectionFactory} specified here should be the target {@link ConnectionFactory} to manage transactions + * for, not a TransactionAwareConnectionFactoryProxy. Only data access code may work with + * TransactionAwareConnectionFactoryProxy, while the transaction manager needs to work on the underlying target + * {@link ConnectionFactory}. If there's nevertheless a TransactionAwareConnectionFactoryProxy passed in, it will be + * unwrapped to extract its target {@link ConnectionFactory}. + *

+ * The {@link ConnectionFactory} passed in here needs to return independent {@link Connection}s. The + * {@link Connection}s may come from a pool (the typical case), but the {@link ConnectionFactory} must not return + * scoped {@link Connection} or the like. + * + * @see TransactionAwareConnectionFactoryProxy + */ + public void setConnectionFactory(@Nullable ConnectionFactory connectionFactory) { + this.connectionFactory = connectionFactory; + } + + /** + * Return the R2DBC {@link ConnectionFactory} that this instance manages transactions for. + */ + @Nullable + public ConnectionFactory getConnectionFactory() { + return this.connectionFactory; + } + + /** + * Obtain the {@link ConnectionFactory} for actual use. + * + * @return the {@link ConnectionFactory} (never {@code null}) + * @throws IllegalStateException in case of no ConnectionFactory set + */ + protected ConnectionFactory obtainConnectionFactory() { + ConnectionFactory connectionFactory = getConnectionFactory(); + Assert.state(connectionFactory != null, "No ConnectionFactory set"); + return connectionFactory; + } + + /** + * Specify whether to enforce the read-only nature of a transaction (as indicated by + * {@link TransactionDefinition#isReadOnly()} through an explicit statement on the transactional connection: "SET + * TRANSACTION READ ONLY" as understood by Oracle, MySQL and Postgres. + *

+ * The exact treatment, including any SQL statement executed on the connection, can be customized through through + * {@link #prepareTransactionalConnection}. + * + * @see #prepareTransactionalConnection + */ + public void setEnforceReadOnly(boolean enforceReadOnly) { + this.enforceReadOnly = enforceReadOnly; + } + + /** + * Return whether to enforce the read-only nature of a transaction through an explicit statement on the transactional + * connection. + * + * @see #setEnforceReadOnly + */ + public boolean isEnforceReadOnly() { + return this.enforceReadOnly; + } + + @Override + public void afterPropertiesSet() { + if (getConnectionFactory() == null) { + throw new IllegalArgumentException("Property 'connectionFactory' is required"); + } + } + + @Override + protected Object doGetTransaction(TransactionSynchronizationManager synchronizationManager) + throws TransactionException { + ConnectionFactoryTransactionObject txObject = new ConnectionFactoryTransactionObject(); + ConnectionHolder conHolder = (ConnectionHolder) synchronizationManager.getResource(obtainConnectionFactory()); + txObject.setConnectionHolder(conHolder, false); + return txObject; + } + + @Override + protected boolean isExistingTransaction(Object transaction) { + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + return (txObject.hasConnectionHolder() && txObject.getConnectionHolder().isTransactionActive()); + } + + @Override + protected Mono doBegin(TransactionSynchronizationManager synchronizationManager, Object transaction, + TransactionDefinition definition) throws TransactionException { + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + + return Mono.defer(() -> { + + Mono connectionMono; + + if (!txObject.hasConnectionHolder() || txObject.getConnectionHolder().isSynchronizedWithTransaction()) { + Mono newCon = Mono.from(obtainConnectionFactory().create()); + + connectionMono = newCon.doOnNext(connection -> { + + if (logger.isDebugEnabled()) { + logger.debug("Acquired Connection [" + newCon + "] for R2DBC transaction"); + } + txObject.setConnectionHolder(new ConnectionHolder(connection), true); + }); + } + else { + txObject.getConnectionHolder().setSynchronizedWithTransaction(true); + connectionMono = Mono.just(txObject.getConnectionHolder().getConnection()); + } + + return connectionMono.flatMap(con -> { + + return prepareTransactionalConnection(con, definition, transaction).then(Mono.from(con.beginTransaction())) + .doOnSuccess(v -> { + txObject.getConnectionHolder().setTransactionActive(true); + + Duration timeout = determineTimeout(definition); + if (!timeout.isNegative() && !timeout.isZero()) { + txObject.getConnectionHolder().setTimeoutInMillis(timeout.toMillis()); + } + + // Bind the connection holder to the thread. + if (txObject.isNewConnectionHolder()) { + synchronizationManager.bindResource(obtainConnectionFactory(), txObject.getConnectionHolder()); + } + }).thenReturn(con).onErrorResume(e -> { + + if (txObject.isNewConnectionHolder()) { + return ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory()) + .doOnTerminate(() -> txObject.setConnectionHolder(null, false)) + .then(Mono.error(e)); + } + return Mono.error(e); + }); + }).onErrorResume(e -> { + + CannotCreateTransactionException ex = new CannotCreateTransactionException( + "Could not open R2DBC Connection for transaction", + e); + + return Mono.error(ex); + }); + }).then(); + } + + /** + * Determine the actual timeout to use for the given definition. Will fall back to this manager's default timeout if + * the transaction definition doesn't specify a non-default value. + * + * @param definition the transaction definition + * @return the actual timeout to use + * @see org.springframework.transaction.TransactionDefinition#getTimeout() + */ + protected Duration determineTimeout(TransactionDefinition definition) { + if (definition.getTimeout() != TransactionDefinition.TIMEOUT_DEFAULT) { + return Duration.ofSeconds(definition.getTimeout()); + } + return Duration.ZERO; + } + + @Override + protected Mono doSuspend(TransactionSynchronizationManager synchronizationManager, Object transaction) + throws TransactionException { + + return Mono.defer(() -> { + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + txObject.setConnectionHolder(null); + return Mono.justOrEmpty(synchronizationManager.unbindResource(obtainConnectionFactory())); + }); + } + + @Override + protected Mono doResume(TransactionSynchronizationManager synchronizationManager, Object transaction, + Object suspendedResources) throws TransactionException { + + return Mono.defer(() -> { + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + txObject.setConnectionHolder(null); + synchronizationManager.bindResource(obtainConnectionFactory(), suspendedResources); + + return Mono.empty(); + }); + } + + @Override + protected Mono doCommit(TransactionSynchronizationManager TransactionSynchronizationManager, + GenericReactiveTransaction status) throws TransactionException { + + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) status.getTransaction(); + Connection connection = txObject.getConnectionHolder().getConnection(); + if (status.isDebug()) { + logger.debug("Committing R2DBC transaction on Connection [" + connection + "]"); + } + + return Mono.from(connection.commitTransaction()) + .onErrorMap(R2dbcException.class, ex -> translateException("R2DBC commit", ex)); + } + + @Override + protected Mono doRollback(TransactionSynchronizationManager TransactionSynchronizationManager, + GenericReactiveTransaction status) throws TransactionException { + + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) status.getTransaction(); + Connection connection = txObject.getConnectionHolder().getConnection(); + if (status.isDebug()) { + logger.debug("Rolling back R2DBC transaction on Connection [" + connection + "]"); + } + + return Mono.from(connection.rollbackTransaction()) + .onErrorMap(R2dbcException.class, ex -> translateException("R2DBC rollback", ex)); + } + + @Override + protected Mono doSetRollbackOnly(TransactionSynchronizationManager synchronizationManager, + GenericReactiveTransaction status) throws TransactionException { + + return Mono.fromRunnable(() -> { + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) status.getTransaction(); + + if (status.isDebug()) { + logger + .debug("Setting R2DBC transaction [" + txObject.getConnectionHolder().getConnection() + "] rollback-only"); + } + txObject.setRollbackOnly(); + }); + } + + @Override + protected Mono doCleanupAfterCompletion(TransactionSynchronizationManager synchronizationManager, + Object transaction) { + + return Mono.defer(() -> { + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + + // Remove the connection holder from the context, if exposed. + if (txObject.isNewConnectionHolder()) { + synchronizationManager.unbindResource(obtainConnectionFactory()); + } + + // Reset connection. + Connection con = txObject.getConnectionHolder().getConnection(); + + Mono afterCleanup = Mono.empty(); + + if (txObject.isMustRestoreAutoCommit()) { + afterCleanup = afterCleanup.then(Mono.from(con.setAutoCommit(true))); + } + + if (txObject.getPreviousIsolationLevel() != null) { + afterCleanup = afterCleanup + .then(Mono.from(con.setTransactionIsolationLevel(txObject.getPreviousIsolationLevel()))); + } + + return afterCleanup.then(Mono.defer(() -> { + try { + if (txObject.isNewConnectionHolder()) { + if (logger.isDebugEnabled()) { + logger.debug("Releasing R2DBC Connection [" + con + "] after transaction"); + } + return ConnectionFactoryUtils.releaseConnection(con, obtainConnectionFactory()); + } + } + finally { + txObject.getConnectionHolder().clear(); + } + return Mono.empty(); + })); + }); + } + + /** + * Prepare the transactional {@link Connection} right after transaction begin. + *

+ * The default implementation executes a "SET TRANSACTION READ ONLY" statement if the {@link #setEnforceReadOnly + * "enforceReadOnly"} flag is set to {@code true} and the transaction definition indicates a read-only transaction. + *

+ * The "SET TRANSACTION READ ONLY" is understood by Oracle, MySQL and Postgres and may work with other databases as + * well. If you'd like to adapt this treatment, override this method accordingly. + * + * @param con the transactional R2DBC Connection + * @param definition the current transaction definition + * @param transaction the transaction object + * @see #setEnforceReadOnly + */ + protected Mono prepareTransactionalConnection(Connection con, TransactionDefinition definition, + Object transaction) { + + ConnectionFactoryTransactionObject txObject = (ConnectionFactoryTransactionObject) transaction; + + Mono prepare = Mono.empty(); + + if (isEnforceReadOnly() && definition.isReadOnly()) { + prepare = Mono.from(con.createStatement("SET TRANSACTION READ ONLY").execute()) + .flatMapMany(Result::getRowsUpdated) + .then(); + } + + // Apply specific isolation level, if any. + IsolationLevel isolationLevelToUse = resolveIsolationLevel(definition.getIsolationLevel()); + if (isolationLevelToUse != null && definition.getIsolationLevel() != TransactionDefinition.ISOLATION_DEFAULT) { + + if (logger.isDebugEnabled()) { + logger + .debug("Changing isolation level of R2DBC Connection [" + con + "] to " + isolationLevelToUse.asSql()); + } + IsolationLevel currentIsolation = con.getTransactionIsolationLevel(); + if (!currentIsolation.asSql().equalsIgnoreCase(isolationLevelToUse.asSql())) { + + txObject.setPreviousIsolationLevel(currentIsolation); + prepare = prepare.then(Mono.from(con.setTransactionIsolationLevel(isolationLevelToUse))); + } + } + + // Switch to manual commit if necessary. This is very expensive in some R2DBC drivers, + // so we don't want to do it unnecessarily (for example if we've explicitly + // configured the connection pool to set it already). + if (con.isAutoCommit()) { + txObject.setMustRestoreAutoCommit(true); + if (logger.isDebugEnabled()) { + logger.debug("Switching R2DBC Connection [" + con + "] to manual commit"); + } + prepare = prepare.then(Mono.from(con.setAutoCommit(false))); + } + + return prepare; + } + + /** + * Resolve the {@link TransactionDefinition#getIsolationLevel() isolation level constant} to a R2DBC + * {@link IsolationLevel}. If you'd like to extend isolation level translation for vendor-specific + * {@link IsolationLevel}s, override this method accordingly. + * + * @param isolationLevel the isolation level to translate. + * @return the resolved isolation level. Can be {@code null} if not resolvable or the isolation level should remain + * {@link TransactionDefinition#ISOLATION_DEFAULT default}. + * @see TransactionDefinition#getIsolationLevel() + */ + @Nullable + protected IsolationLevel resolveIsolationLevel(int isolationLevel) { + switch (isolationLevel) { + case TransactionDefinition.ISOLATION_READ_COMMITTED: + return IsolationLevel.READ_COMMITTED; + case TransactionDefinition.ISOLATION_READ_UNCOMMITTED: + return IsolationLevel.READ_UNCOMMITTED; + case TransactionDefinition.ISOLATION_REPEATABLE_READ: + return IsolationLevel.REPEATABLE_READ; + case TransactionDefinition.ISOLATION_SERIALIZABLE: + return IsolationLevel.SERIALIZABLE; + } + return null; + } + + /** + * Translate the given R2DBC commit/rollback exception to a common Spring exception to propagate from the + * {@link #commit}/{@link #rollback} call. + * + * @param task the task description (commit or rollback). + * @param ex the SQLException thrown from commit/rollback. + * @return the translated exception to emit + */ + protected RuntimeException translateException(String task, R2dbcException ex) { + return ConnectionFactoryUtils.convertR2dbcException(task, null, ex); + } + + + /** + * ConnectionFactory transaction object, representing a ConnectionHolder. Used as transaction object by + * ConnectionFactoryTransactionManager. + */ + private static class ConnectionFactoryTransactionObject { + + @Nullable + private ConnectionHolder connectionHolder; + + @Nullable + private IsolationLevel previousIsolationLevel; + + private boolean newConnectionHolder; + + private boolean mustRestoreAutoCommit; + + + void setConnectionHolder(@Nullable ConnectionHolder connectionHolder, boolean newConnectionHolder) { + setConnectionHolder(connectionHolder); + this.newConnectionHolder = newConnectionHolder; + } + + boolean isNewConnectionHolder() { + return this.newConnectionHolder; + } + + void setRollbackOnly() { + getConnectionHolder().setRollbackOnly(); + } + + public void setConnectionHolder(@Nullable ConnectionHolder connectionHolder) { + this.connectionHolder = connectionHolder; + } + + public ConnectionHolder getConnectionHolder() { + Assert.state(this.connectionHolder != null, "No ConnectionHolder available"); + return this.connectionHolder; + } + + public boolean hasConnectionHolder() { + return (this.connectionHolder != null); + } + + public void setPreviousIsolationLevel(@Nullable IsolationLevel previousIsolationLevel) { + this.previousIsolationLevel = previousIsolationLevel; + } + + @Nullable + public IsolationLevel getPreviousIsolationLevel() { + return this.previousIsolationLevel; + } + + public void setMustRestoreAutoCommit(boolean mustRestoreAutoCommit) { + this.mustRestoreAutoCommit = mustRestoreAutoCommit; + } + + public boolean isMustRestoreAutoCommit() { + return this.mustRestoreAutoCommit; + } + } + +} + diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/SingleConnectionFactory.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/SingleConnectionFactory.java new file mode 100644 index 000000000000..1970a6c67627 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/SingleConnectionFactory.java @@ -0,0 +1,296 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.concurrent.atomic.AtomicReference; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactories; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import io.r2dbc.spi.Wrapped; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link DelegatingConnectionFactory} that wraps a + * single R2DBC Connection which is not closed after use. + * Obviously, this is not multi-threading capable. + * + *

Note that at shutdown, someone should close the underlying + * Connection via the {@code close()} method. Client code will + * never call close on the Connection handle if it is + * SmartConnectionFactory-aware (e.g. uses + * {@link ConnectionFactoryUtils#releaseConnection(Connection, ConnectionFactory)}). + * + *

If client code will call {@link Connection#close()} in the + * assumption of a pooled Connection, like when using persistence tools, + * set "suppressClose" to "true". This will return a close-suppressing + * proxy instead of the physical Connection. + * + *

This is primarily intended for testing and pipelining usage of connections. + * For example, it enables easy testing outside an application server, for code + * that expects to work on a {@link ConnectionFactory}. + * Note that this implementation does not act as a connection pool-like utility. + * Connection pooling requires a {@link ConnectionFactory} implemented by e.g. + * {@code r2dbc-pool}. + * + * @author Mark Paluch + * @since 5.3 + * @see #create() + * @see Connection#close() + * @see ConnectionFactoryUtils#releaseConnection(Connection, ConnectionFactory) + */ +public class SingleConnectionFactory extends DelegatingConnectionFactory + implements DisposableBean { + + /** Create a close-suppressing proxy?. */ + private boolean suppressClose; + + /** Override auto-commit state?. */ + private @Nullable Boolean autoCommit; + + /** Wrapped Connection. */ + private final AtomicReference target = new AtomicReference<>(); + + /** Proxy Connection. */ + private @Nullable Connection connection; + + private final Mono connectionEmitter; + + + /** + * Constructor for bean-style configuration. + */ + public SingleConnectionFactory(ConnectionFactory targetConnectionFactory) { + super(targetConnectionFactory); + this.connectionEmitter = super.create().cache(); + } + + /** + * Create a new {@link SingleConnectionFactory} using a R2DBC connection URL. + * + * @param url the R2DBC URL to use for accessing {@link ConnectionFactory} discovery. + * @param suppressClose if the returned {@link Connection} should be a close-suppressing proxy or the physical + * {@link Connection}. + * @see ConnectionFactories#get(String) + */ + public SingleConnectionFactory(String url, boolean suppressClose) { + super(ConnectionFactories.get(url)); + this.suppressClose = suppressClose; + this.connectionEmitter = super.create().cache(); + } + + /** + * Create a new {@link SingleConnectionFactory} with a given {@link Connection} and + * {@link ConnectionFactoryMetadata}. + * + * @param target underlying target {@link Connection}. + * @param metadata {@link ConnectionFactory} metadata to be associated with this {@link ConnectionFactory}. + * @param suppressClose if the {@link Connection} should be wrapped with a {@link Connection} that suppresses + * {@code close()} calls (to allow for normal {@code close()} usage in applications that expect a pooled + * {@link Connection} but do not know our {@link SmartConnectionFactory} interface). + */ + public SingleConnectionFactory(Connection target, ConnectionFactoryMetadata metadata, + boolean suppressClose) { + super(new ConnectionFactory() { + @Override + public Publisher create() { + return Mono.just(target); + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + return metadata; + } + }); + Assert.notNull(target, "Connection must not be null"); + Assert.notNull(metadata, "ConnectionFactoryMetadata must not be null"); + this.target.set(target); + this.connectionEmitter = Mono.just(target); + this.suppressClose = suppressClose; + this.connection = (suppressClose ? getCloseSuppressingConnectionProxy(target) : target); + } + + + /** + * Set whether the returned {@link Connection} should be a close-suppressing proxy or the physical {@link Connection}. + */ + public void setSuppressClose(boolean suppressClose) { + this.suppressClose = suppressClose; + } + + /** + * Return whether the returned {@link Connection} will be a close-suppressing proxy or the physical + * {@link Connection}. + */ + protected boolean isSuppressClose() { + return this.suppressClose; + } + + /** + * Set whether the returned {@link Connection}'s "autoCommit" setting should be overridden. + */ + public void setAutoCommit(boolean autoCommit) { + this.autoCommit = autoCommit; + } + + /** + * Return whether the returned {@link Connection}'s "autoCommit" setting should be overridden. + * + * @return the "autoCommit" value, or {@code null} if none to be applied + */ + @Nullable + protected Boolean getAutoCommitValue() { + return this.autoCommit; + } + + @Override + public Mono create() { + + Connection connection = this.target.get(); + + return this.connectionEmitter.map(connectionToUse -> { + + if (connection == null) { + this.target.compareAndSet(connection, connectionToUse); + this.connection = (isSuppressClose() ? getCloseSuppressingConnectionProxy(connectionToUse) : connectionToUse); + } + + return this.connection; + }).flatMap(this::prepareConnection); + } + + /** + * Close the underlying {@link Connection}. The provider of this {@link ConnectionFactory} needs to care for proper + * shutdown. + *

+ * As this bean implements {@link DisposableBean}, a bean factory will automatically invoke this on destruction of its + * cached singletons. + */ + @Override + public void destroy() { + resetConnection().block(); + } + + /** + * Reset the underlying shared Connection, to be reinitialized on next access. + */ + public Mono resetConnection() { + + Connection connection = this.target.get(); + + if (connection == null) { + return Mono.empty(); + } + + return Mono.defer(() -> { + + if (this.target.compareAndSet(connection, null)) { + + this.connection = null; + + return Mono.from(connection.close()); + } + + return Mono.empty(); + }); + } + + /** + * Prepare the {@link Connection} before using it. Applies {@link #getAutoCommitValue() auto-commit} settings if + * configured. + * + * @param connection the requested {@link Connection}. + * @return the prepared {@link Connection}. + */ + protected Mono prepareConnection(Connection connection) { + + Boolean autoCommit = getAutoCommitValue(); + if (autoCommit != null) { + return Mono.from(connection.setAutoCommit(autoCommit)).thenReturn(connection); + } + + return Mono.just(connection); + } + + /** + * Wrap the given {@link Connection} with a proxy that delegates every method call to it but suppresses close calls. + * + * @param target the original {@link Connection} to wrap. + * @return the wrapped Connection. + */ + protected Connection getCloseSuppressingConnectionProxy(Connection target) { + return (Connection) Proxy.newProxyInstance(SingleConnectionFactory.class.getClassLoader(), + new Class[] { Connection.class, Wrapped.class }, new CloseSuppressingInvocationHandler(target)); + } + + + /** + * Invocation handler that suppresses close calls on R2DBC Connections. + * + * @see Connection#close() + */ + private static class CloseSuppressingInvocationHandler implements InvocationHandler { + + private final Connection target; + + CloseSuppressingInvocationHandler(Connection target) { + this.target = target; + } + + @Override + @Nullable + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + // Invocation on ConnectionProxy interface coming in... + + if (method.getName().equals("equals")) { + // Only consider equal when proxies are identical. + return proxy == args[0]; + } + else if (method.getName().equals("hashCode")) { + // Use hashCode of PersistenceManager proxy. + return System.identityHashCode(proxy); + } + else if (method.getName().equals("unwrap")) { + return this.target; + } + else if (method.getName().equals("close")) { + // Handle close method: suppress, not valid. + return Mono.empty(); + } + + // Invoke method on target Connection. + try { + return method.invoke(this.target, args); + } + catch (InvocationTargetException ex) { + throw ex.getTargetException(); + } + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/TransactionAwareConnectionFactoryProxy.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/TransactionAwareConnectionFactoryProxy.java new file mode 100644 index 000000000000..9cbd9083b71b --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/TransactionAwareConnectionFactoryProxy.java @@ -0,0 +1,198 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Wrapped; +import reactor.core.publisher.Mono; + +import org.springframework.lang.Nullable; +import org.springframework.util.ReflectionUtils; + +/** + * Proxy for a target R2DBC {@link ConnectionFactory}, adding awareness + * of Spring-managed transactions. + * + *

Data access code that should remain unaware of Spring's data access + * support can work with this proxy to seamlessly participate in + * Spring-managed transactions. + * Note that the transaction manager, for example {@link R2dbcTransactionManager}, + * still needs to work with the underlying {@link ConnectionFactory}, + * not with this proxy. + * + *

Make sure that {@link TransactionAwareConnectionFactoryProxy} is the outermost + * {@link ConnectionFactory} of a chain of {@link ConnectionFactory} proxies/adapters. + * {@link TransactionAwareConnectionFactoryProxy} can delegate either directly to the + * target connection pool or to some intermediary proxy/adapter. + * + *

Delegates to {@link ConnectionFactoryUtils} for automatically participating + * in thread-bound transactions, for example managed by {@link R2dbcTransactionManager}. + * {@link #create()} calls and {@code close} calls on returned {@link Connection} + * will behave properly within a transaction, i.e. always operate on the + * transactional Connection. If not within a transaction, normal {@link ConnectionFactory} + * behavior applies. + * + *

This proxy allows data access code to work with the plain R2DBC API. However, + * if possible, use Spring's {@link ConnectionFactoryUtils} or {@code DatabaseClient} + * to get transaction participation even without a proxy for the target + * {@link ConnectionFactory}, avoiding the need to define such a proxy in the first place. + * + *

NOTE: This {@link ConnectionFactory} proxy needs to return wrapped + * {@link Connection}s (which implement the {@link ConnectionProxy} interface) in order + * to handle close calls properly. Use {@link Wrapped#unwrap()} to retrieve + * the native R2DBC Connection. + * + * @author Mark Paluch + * @author Christoph Strobl + * @since 5.3 + * @see ConnectionFactory#create + * @see Connection#close + * @see ConnectionFactoryUtils#doGetConnection + * @see ConnectionFactoryUtils#doReleaseConnection + */ +public class TransactionAwareConnectionFactoryProxy extends DelegatingConnectionFactory { + + /** + * Create a new {@link TransactionAwareConnectionFactoryProxy}. + * + * @param targetConnectionFactory the target {@link ConnectionFactory}. + * @throws IllegalArgumentException if given {@link ConnectionFactory} is {@code null}. + */ + public TransactionAwareConnectionFactoryProxy(ConnectionFactory targetConnectionFactory) { + super(targetConnectionFactory); + } + + + /** + * Delegates to {@link ConnectionFactoryUtils} for automatically participating in Spring-managed transactions. + *

+ * The returned {@link ConnectionFactory} handle implements the {@link ConnectionProxy} interface, allowing to + * retrieve the underlying target {@link Connection}. + * + * @return a transactional {@link Connection} if any, a new one else. + * @see ConnectionFactoryUtils#doGetConnection + * @see ConnectionProxy#getTargetConnection + */ + @Override + public Mono create() { + return getTransactionAwareConnectionProxy(obtainTargetConnectionFactory()); + } + + /** + * Wraps the given {@link Connection} with a proxy that delegates every method call to it but delegates + * {@code close()} calls to {@link ConnectionFactoryUtils}. + * + * @param targetConnectionFactory the {@link ConnectionFactory} that the {@link Connection} came from. + * @return the wrapped {@link Connection}. + * @see Connection#close() + * @see ConnectionFactoryUtils#doReleaseConnection + */ + protected Mono getTransactionAwareConnectionProxy(ConnectionFactory targetConnectionFactory) { + return ConnectionFactoryUtils.getConnection(targetConnectionFactory) + .map(connection -> proxyConnection(connection, targetConnectionFactory)); + } + + private static Connection proxyConnection(Connection connection, ConnectionFactory targetConnectionFactory) { + + return (Connection) Proxy.newProxyInstance(TransactionAwareConnectionFactoryProxy.class.getClassLoader(), + new Class[] { Connection.class, Wrapped.class }, + new TransactionAwareInvocationHandler(connection, targetConnectionFactory)); + } + + + /** + * Invocation handler that delegates close calls on R2DBC Connections to {@link ConnectionFactoryUtils} for being + * aware of context-bound transactions. + */ + private static class TransactionAwareInvocationHandler implements InvocationHandler { + + private final Connection connection; + + private final ConnectionFactory targetConnectionFactory; + + private boolean closed = false; + + TransactionAwareInvocationHandler(Connection connection, ConnectionFactory targetConnectionFactory) { + this.connection = connection; + this.targetConnectionFactory = targetConnectionFactory; + } + + @Override + @Nullable + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if (ReflectionUtils.isObjectMethod(method)) { + + if (ReflectionUtils.isToStringMethod(method)) { + return proxyToString(proxy); + } + + if (ReflectionUtils.isEqualsMethod(method)) { + return (proxy == args[0]); + } + + if (ReflectionUtils.isHashCodeMethod(method)) { + return System.identityHashCode(proxy); + } + } + + // Invocation on ConnectionProxy interface coming in... + switch (method.getName()) { + + case "unwrap": + return this.connection; + case "close": + // Handle close method: only close if not within a transaction. + return ConnectionFactoryUtils.doReleaseConnection(this.connection, this.targetConnectionFactory) + .doOnSubscribe(n -> this.closed = true); + case "isClosed": + return this.closed; + } + + if (this.closed) { + throw new IllegalStateException("Connection handle already closed"); + } + + // Invoke method on target Connection. + try { + return method.invoke(this.connection, args); + } + catch (InvocationTargetException ex) { + throw ex.getTargetException(); + } + } + + private String proxyToString(@Nullable Object proxy) { + // Allow for differentiating between the proxy and the raw Connection. + StringBuilder sb = new StringBuilder("Transaction-aware proxy for target Connection "); + if (this.connection != null) { + sb.append("[").append(this.connection.toString()).append("]"); + } + else { + sb.append(" from ConnectionFactory [").append(this.targetConnectionFactory).append("]"); + } + return sb.toString(); + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/CannotReadScriptException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/CannotReadScriptException.java new file mode 100644 index 000000000000..68b9336b273f --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/CannotReadScriptException.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import org.springframework.core.io.support.EncodedResource; + +/** + * Thrown by {@link ScriptUtils} if an SQL script cannot be read. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class CannotReadScriptException extends ScriptException { + + /** + * Create a new {@code CannotReadScriptException}. + * @param resource the resource that cannot be read from. + * @param cause the underlying cause of the resource access failure. + */ + public CannotReadScriptException(EncodedResource resource, Throwable cause) { + super("Cannot read SQL script from " + resource, cause); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/CompositeDatabasePopulator.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/CompositeDatabasePopulator.java new file mode 100644 index 000000000000..7082ad0f5a4d --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/CompositeDatabasePopulator.java @@ -0,0 +1,92 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import io.r2dbc.spi.Connection; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.util.Assert; + +/** + * Composite {@link DatabasePopulator} that delegates to a list of given + * {@link DatabasePopulator} implementations, executing all scripts. + * + * @author Mark Paluch + * @since 5.3 + */ +public class CompositeDatabasePopulator implements DatabasePopulator { + + private final List populators = new ArrayList<>(4); + + + /** + * Create an empty {@code CompositeDatabasePopulator}. + * @see #setPopulators + * @see #addPopulators + */ + public CompositeDatabasePopulator() {} + + /** + * Create a {@code CompositeDatabasePopulator}. with the given populators. + * @param populators one or more populators to delegate to. + */ + public CompositeDatabasePopulator(Collection populators) { + Assert.notNull(populators, "Collection of DatabasePopulator must not be null"); + this.populators.addAll(populators); + } + + /** + * Create a {@code CompositeDatabasePopulator} with the given populators. + * @param populators one or more populators to delegate to. + */ + public CompositeDatabasePopulator(DatabasePopulator... populators) { + Assert.notNull(populators, "DatabasePopulators must not be null"); + this.populators.addAll(Arrays.asList(populators)); + } + + + /** + * Specify one or more populators to delegate to. + */ + public void setPopulators(DatabasePopulator... populators) { + Assert.notNull(populators, "DatabasePopulators must not be null"); + this.populators.clear(); + this.populators.addAll(Arrays.asList(populators)); + } + + /** + * Add one or more populators to the list of delegates. + */ + public void addPopulators(DatabasePopulator... populators) { + Assert.notNull(populators, "DatabasePopulators must not be null"); + this.populators.addAll(Arrays.asList(populators)); + } + + @Override + public Mono populate(Connection connection) throws ScriptException { + Assert.notNull(connection, "Connection must not be null"); + return Flux.fromIterable(this.populators).concatMap(populator -> populator.populate(connection)) + .then(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ConnectionFactoryInitializer.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ConnectionFactoryInitializer.java new file mode 100644 index 000000000000..8fa8fc04d242 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ConnectionFactoryInitializer.java @@ -0,0 +1,116 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Used to {@linkplain #setDatabasePopulator set up} a database during + * initialization and {@link #setDatabaseCleaner clean up} a database during + * destruction. + * + * @author Mark Paluch + * @since 5.3 + * @see DatabasePopulator + */ +public class ConnectionFactoryInitializer implements InitializingBean, DisposableBean { + + @Nullable + private ConnectionFactory connectionFactory; + + @Nullable + private DatabasePopulator databasePopulator; + + @Nullable + private DatabasePopulator databaseCleaner; + + private boolean enabled = true; + + + /** + * The {@link ConnectionFactory} for the database to populate when this component is initialized and to clean up when + * this component is shut down. + *

+ * This property is mandatory with no default provided. + * + * @param connectionFactory the R2DBC {@link ConnectionFactory}. + */ + public void setConnectionFactory(ConnectionFactory connectionFactory) { + this.connectionFactory = connectionFactory; + } + + /** + * Set the {@link DatabasePopulator} to execute during the bean initialization phase. + * + * @param databasePopulator the {@link DatabasePopulator} to use during initialization + * @see #setDatabaseCleaner + */ + public void setDatabasePopulator(DatabasePopulator databasePopulator) { + this.databasePopulator = databasePopulator; + } + + /** + * Set the {@link DatabasePopulator} to execute during the bean destruction phase, cleaning up the database and + * leaving it in a known state for others. + * + * @param databaseCleaner the {@link DatabasePopulator} to use during destruction + * @see #setDatabasePopulator + */ + public void setDatabaseCleaner(DatabasePopulator databaseCleaner) { + this.databaseCleaner = databaseCleaner; + } + + /** + * Flag to explicitly enable or disable the {@link #setDatabasePopulator database populator} and + * {@link #setDatabaseCleaner database cleaner}. + * + * @param enabled {@code true} if the database populator and database cleaner should be called on startup and + * shutdown, respectively + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + /** + * Use the {@link #setDatabasePopulator database populator} to set up the database. + */ + @Override + public void afterPropertiesSet() { + execute(this.databasePopulator); + } + + /** + * Use the {@link #setDatabaseCleaner database cleaner} to clean up the database. + */ + @Override + public void destroy() { + execute(this.databaseCleaner); + } + + private void execute(@Nullable DatabasePopulator populator) { + Assert.state(this.connectionFactory != null, "ConnectionFactory must be set"); + if (this.enabled && populator != null) { + populator.populate(this.connectionFactory).block(); + } + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/DatabasePopulator.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/DatabasePopulator.java new file mode 100644 index 000000000000..521aa36d8ed7 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/DatabasePopulator.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Mono; + +import org.springframework.dao.DataAccessException; +import org.springframework.r2dbc.connection.ConnectionFactoryUtils; +import org.springframework.util.Assert; + +/** + * Strategy used to populate, initialize, or clean up a database. + * + * @author Mark Paluch + * @since 5.3 + * @see ResourceDatabasePopulator + * @see ConnectionFactoryInitializer + */ +@FunctionalInterface +public interface DatabasePopulator { + + /** + * Populate, initialize, or clean up the database using the + * provided R2DBC {@link Connection}. + * + * @param connection the R2DBC connection to use to populate the db; + * already configured and ready to use, must not be {@code null} + * @return {@link Mono} that initiates script execution and is + * notified upon completion + * @throws ScriptException in all other error cases + */ + Mono populate(Connection connection) throws ScriptException; + + /** + * Execute the given {@link DatabasePopulator} against the given {@link ConnectionFactory}. + * @param connectionFactory the {@link ConnectionFactory} to execute against + * @return {@link Mono} that initiates {@link DatabasePopulator#populate(Connection)} + * and is notified upon completion + */ + default Mono populate(ConnectionFactory connectionFactory) + throws DataAccessException { + Assert.notNull(connectionFactory, "ConnectionFactory must not be null"); + return Mono.usingWhen(ConnectionFactoryUtils.getConnection(connectionFactory), // + this::populate, // + connection -> ConnectionFactoryUtils.releaseConnection(connection, connectionFactory), // + (connection, err) -> ConnectionFactoryUtils.releaseConnection(connection, connectionFactory), + connection -> ConnectionFactoryUtils.releaseConnection(connection, connectionFactory)) + .onErrorMap(ex -> !(ex instanceof ScriptException), + ex -> new UncategorizedScriptException("Failed to execute database script", ex)); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ResourceDatabasePopulator.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ResourceDatabasePopulator.java new file mode 100644 index 000000000000..9e545e6dfe36 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ResourceDatabasePopulator.java @@ -0,0 +1,273 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import io.r2dbc.spi.Connection; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.support.EncodedResource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Populates, initializes, or cleans up a database using SQL + * scripts defined in external resources. + *

    + *
  • Call {@link #addScript} to add a single SQL script location. + *
  • Call {@link #addScripts} to add multiple SQL script locations. + *
  • Consult the setter methods in this class for further configuration options. + *
  • Call {@link #populate} to initialize or clean up the database using the configured scripts. + *
+ * + * @author Keith Donald + * @author Dave Syer + * @author Juergen Hoeller + * @author Chris Beams + * @author Oliver Gierke + * @author Sam Brannen + * @author Chris Baldwin + * @author Phillip Webb + * @author Mark Paluch + * @since 5.3 + * @see ScriptUtils + */ +public class ResourceDatabasePopulator implements DatabasePopulator { + + List scripts = new ArrayList<>(); + + @Nullable + private Charset sqlScriptEncoding; + + private String separator = ScriptUtils.DEFAULT_STATEMENT_SEPARATOR; + + private String[] commentPrefixes = ScriptUtils.DEFAULT_COMMENT_PREFIXES; + + private String blockCommentStartDelimiter = ScriptUtils.DEFAULT_BLOCK_COMMENT_START_DELIMITER; + + private String blockCommentEndDelimiter = ScriptUtils.DEFAULT_BLOCK_COMMENT_END_DELIMITER; + + private boolean continueOnError = false; + + private boolean ignoreFailedDrops = false; + + private DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory(); + + + /** + * Create a new {@code ResourceDatabasePopulator} with default settings. + */ + public ResourceDatabasePopulator() { + } + + /** + * Create a new {@code ResourceDatabasePopulator} with default settings for the supplied scripts. + * @param scripts the scripts to execute to initialize or clean up the database (never {@code null}) + */ + public ResourceDatabasePopulator(Resource... scripts) { + setScripts(scripts); + } + + /** + * Construct a new {@code ResourceDatabasePopulator} with the supplied values. + * @param continueOnError flag to indicate that all failures in SQL should be + * logged but not cause a failure + * @param ignoreFailedDrops flag to indicate that a failed SQL {@code DROP} + * statement can be ignored + * @param sqlScriptEncoding the encoding for the supplied SQL scripts + * (may be {@code null} or empty to indicate platform encoding) + * @param scripts the scripts to execute to initialize or clean up the database + * (never {@code null}) + */ + public ResourceDatabasePopulator(boolean continueOnError, boolean ignoreFailedDrops, + @Nullable String sqlScriptEncoding, Resource... scripts) { + + this.continueOnError = continueOnError; + this.ignoreFailedDrops = ignoreFailedDrops; + setSqlScriptEncoding(sqlScriptEncoding); + setScripts(scripts); + } + + + /** + * Add a script to execute to initialize or clean up the database. + * @param script the path to an SQL script (never {@code null}) + */ + public void addScript(Resource script) { + Assert.notNull(script, "'script' must not be null"); + this.scripts.add(script); + } + + /** + * Add multiple scripts to execute to initialize or clean up the database. + * @param scripts the scripts to execute (never {@code null}) + */ + public void addScripts(Resource... scripts) { + assertContentsOfScriptArray(scripts); + this.scripts.addAll(Arrays.asList(scripts)); + } + + /** + * Set the scripts to execute to initialize or clean up the database, + * replacing any previously added scripts. + * @param scripts the scripts to execute (never {@code null}) + */ + public void setScripts(Resource... scripts) { + assertContentsOfScriptArray(scripts); + // Ensure that the list is modifiable + this.scripts = new ArrayList<>(Arrays.asList(scripts)); + } + + private void assertContentsOfScriptArray(Resource... scripts) { + Assert.notNull(scripts, "'scripts' must not be null"); + Assert.noNullElements(scripts, "'scripts' must not contain null elements"); + } + + /** + * Specify the encoding for the configured SQL scripts, + * if different from the platform encoding. + * @param sqlScriptEncoding the encoding used in scripts + * (may be {@code null} or empty to indicate platform encoding) + * @see #addScript(Resource) + */ + public void setSqlScriptEncoding(@Nullable String sqlScriptEncoding) { + setSqlScriptEncoding(StringUtils.hasText(sqlScriptEncoding) ? Charset.forName(sqlScriptEncoding) : null); + } + + /** + * Specify the encoding for the configured SQL scripts, + * if different from the platform encoding. + * @param sqlScriptEncoding the encoding used in scripts + * (may be {@code null} or empty to indicate platform encoding) + * @see #addScript(Resource) + */ + public void setSqlScriptEncoding(@Nullable Charset sqlScriptEncoding) { + this.sqlScriptEncoding = sqlScriptEncoding; + } + + /** + * Specify the statement separator, if a custom one. + *

Defaults to {@code ";"} if not specified and falls back to {@code "\n"} + * as a last resort; may be set to {@link ScriptUtils#EOF_STATEMENT_SEPARATOR} + * to signal that each script contains a single statement without a separator. + * @param separator the script statement separator + */ + public void setSeparator(String separator) { + this.separator = separator; + } + + /** + * Set the prefix that identifies single-line comments within the SQL scripts. + *

Defaults to {@code "--"}. + * @param commentPrefix the prefix for single-line comments + * @see #setCommentPrefixes(String...) + */ + public void setCommentPrefix(String commentPrefix) { + Assert.hasText(commentPrefix, "'commentPrefix' must not be null or empty"); + this.commentPrefixes = new String[] { commentPrefix }; + } + + /** + * Set the prefixes that identify single-line comments within the SQL scripts. + *

Defaults to {@code ["--"]}. + * @param commentPrefixes the prefixes for single-line comments + */ + public void setCommentPrefixes(String... commentPrefixes) { + Assert.notEmpty(commentPrefixes, "'commentPrefixes' must not be null or empty"); + Assert.noNullElements(commentPrefixes, "'commentPrefixes' must not contain null elements"); + this.commentPrefixes = commentPrefixes; + } + + /** + * Set the start delimiter that identifies block comments within the SQL + * scripts. + *

Defaults to {@code "/*"}. + * @param blockCommentStartDelimiter the start delimiter for block comments + * (never {@code null} or empty) + * @see #setBlockCommentEndDelimiter + */ + public void setBlockCommentStartDelimiter(String blockCommentStartDelimiter) { + Assert.hasText(blockCommentStartDelimiter, "'blockCommentStartDelimiter' must not be null or empty"); + this.blockCommentStartDelimiter = blockCommentStartDelimiter; + } + + /** + * Set the end delimiter that identifies block comments within the SQL + * scripts. + *

Defaults to "*/". + * @param blockCommentEndDelimiter the end delimiter for block comments + * (never {@code null} or empty) + * @see #setBlockCommentStartDelimiter + */ + public void setBlockCommentEndDelimiter(String blockCommentEndDelimiter) { + Assert.hasText(blockCommentEndDelimiter, "'blockCommentEndDelimiter' must not be null or empty"); + this.blockCommentEndDelimiter = blockCommentEndDelimiter; + } + + /** + * Flag to indicate that all failures in SQL should be logged but not cause a failure. + *

Defaults to {@code false}. + * @param continueOnError {@code true} if script execution should continue on error + */ + public void setContinueOnError(boolean continueOnError) { + this.continueOnError = continueOnError; + } + + /** + * Flag to indicate that a failed SQL {@code DROP} statement can be ignored. + *

This is useful for a non-embedded database whose SQL dialect does not + * support an {@code IF EXISTS} clause in a {@code DROP} statement. + *

The default is {@code false} so that if the populator runs accidentally, it will + * fail fast if a script starts with a {@code DROP} statement. + * @param ignoreFailedDrops {@code true} if failed drop statements should be ignored + */ + public void setIgnoreFailedDrops(boolean ignoreFailedDrops) { + this.ignoreFailedDrops = ignoreFailedDrops; + } + + /** + * Set the {@link DataBufferFactory} to use for {@link Resource} loading. + *

Defaults to {@link DefaultDataBufferFactory}. + * @param dataBufferFactory the {@link DataBufferFactory} to use, must not be {@code null} + */ + public void setDataBufferFactory(DataBufferFactory dataBufferFactory) { + Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); + this.dataBufferFactory = dataBufferFactory; + } + + + @Override + public Mono populate(Connection connection) throws ScriptException { + Assert.notNull(connection, "Connection must not be null"); + return Flux.fromIterable(this.scripts).concatMap(resource -> { + EncodedResource encodedScript = new EncodedResource(resource, this.sqlScriptEncoding); + return ScriptUtils.executeSqlScript(connection, encodedScript, this.dataBufferFactory, this.continueOnError, + this.ignoreFailedDrops, this.commentPrefixes, this.separator, this.blockCommentStartDelimiter, + this.blockCommentEndDelimiter); + }).then(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptException.java new file mode 100644 index 000000000000..dd6b16269374 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import org.springframework.dao.DataAccessException; +import org.springframework.lang.Nullable; + +/** + * Root of the hierarchy of data access exceptions that are related to processing of SQL scripts. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public abstract class ScriptException extends DataAccessException { + + /** + * Create a new {@code ScriptException}. + * @param message the detail message + */ + public ScriptException(String message) { + super(message); + } + + /** + * Create a new {@code ScriptException}. + * @param message the detail message + * @param cause the root cause + */ + public ScriptException(String message, @Nullable Throwable cause) { + super(message, cause); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptParseException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptParseException.java new file mode 100644 index 000000000000..54cc2981785b --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptParseException.java @@ -0,0 +1,56 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import org.springframework.core.io.support.EncodedResource; +import org.springframework.lang.Nullable; + +/** + * Thrown by {@link ScriptUtils} if an SQL script cannot be properly parsed. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class ScriptParseException extends ScriptException { + + /** + * Create a new {@code ScriptParseException}. + * @param message detailed message + * @param resource the resource from which the SQL script was read + */ + public ScriptParseException(String message, @Nullable EncodedResource resource) { + super(buildMessage(message, resource)); + } + + /** + * Create a new {@code ScriptParseException}. + * @param message detailed message + * @param resource the resource from which the SQL script was read + * @param cause the underlying cause of the failure + */ + public ScriptParseException(String message, @Nullable EncodedResource resource, @Nullable Throwable cause) { + super(buildMessage(message, resource), cause); + } + + + private static String buildMessage(String message, @Nullable EncodedResource resource) { + return String.format("Failed to parse SQL script from resource [%s]: %s", + (resource == null ? "" : resource), message); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptStatementFailedException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptStatementFailedException.java new file mode 100644 index 000000000000..7fc84fc98477 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptStatementFailedException.java @@ -0,0 +1,58 @@ +/* + * Copyright 2002-2015 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import org.springframework.core.io.support.EncodedResource; + +/** + * Thrown by {@link ScriptUtils} if a statement in an SQL script failed when + * executing it against the target database. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class ScriptStatementFailedException extends ScriptException { + + /** + * Create a new {@code ScriptStatementFailedException}. + * @param stmt the actual SQL statement that failed + * @param stmtNumber the statement number in the SQL script (i.e., + * the nth statement present in the resource) + * @param encodedResource the resource from which the SQL statement was read + * @param cause the underlying cause of the failure + */ + public ScriptStatementFailedException(String stmt, int stmtNumber, EncodedResource encodedResource, Throwable cause) { + super(buildErrorMessage(stmt, stmtNumber, encodedResource), cause); + } + + + /** + * Build an error message for an SQL script execution failure, + * based on the supplied arguments. + * @param stmt the actual SQL statement that failed + * @param stmtNumber the statement number in the SQL script (i.e., + * the nth statement present in the resource) + * @param encodedResource the resource from which the SQL statement was read + * @return an error message suitable for an exception's detail message + * or logging + */ + public static String buildErrorMessage(String stmt, int stmtNumber, EncodedResource encodedResource) { + return String.format("Failed to execute SQL script statement #%s of %s: %s", stmtNumber, encodedResource, stmt); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java new file mode 100644 index 000000000000..fb7788c964c6 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/ScriptUtils.java @@ -0,0 +1,666 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.LineNumberReader; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.Result; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.Resource; +import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.support.EncodedResource; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Generic utility methods for working with SQL scripts. + *

Mainly for internal use within the framework. + * + * @author Thomas Risberg + * @author Sam Brannen + * @author Juergen Hoeller + * @author Keith Donald + * @author Dave Syer + * @author Chris Beams + * @author Oliver Gierke + * @author Chris Baldwin + * @author Nicolas Debeissat + * @author Phillip Webb + * @author Mark Paluch + * @since 5.3 + */ +public abstract class ScriptUtils { + + /** + * Default statement separator within SQL scripts: {@code ";"}. + */ + public static final String DEFAULT_STATEMENT_SEPARATOR = ";"; + + /** + * Fallback statement separator within SQL scripts: {@code "\n"}. + *

Used if neither a custom separator nor the + * {@link #DEFAULT_STATEMENT_SEPARATOR} is present in a given script. + */ + public static final String FALLBACK_STATEMENT_SEPARATOR = "\n"; + + /** + * End of file (EOF) SQL statement separator: {@code "^^^ END OF SCRIPT ^^^"}. + *

This value may be supplied as the {@code separator} to {@link + * #executeSqlScript(Connection, EncodedResource, DataBufferFactory, boolean, boolean, String[], String, String, String)} + * to denote that an SQL script contains a single statement (potentially + * spanning multiple lines) with no explicit statement separator. Note that + * such a script should not actually contain this value; it is merely a + * virtual statement separator. + */ + public static final String EOF_STATEMENT_SEPARATOR = "^^^ END OF SCRIPT ^^^"; + + /** + * Default prefix for single-line comments within SQL scripts: {@code "--"}. + */ + public static final String DEFAULT_COMMENT_PREFIX = "--"; + + /** + * Default prefixes for single-line comments within SQL scripts: {@code ["--"]}. + */ + public static final String[] DEFAULT_COMMENT_PREFIXES = {DEFAULT_COMMENT_PREFIX}; + + /** + * Default start delimiter for block comments within SQL scripts: {@code "/*"}. + */ + public static final String DEFAULT_BLOCK_COMMENT_START_DELIMITER = "/*"; + + /** + * Default end delimiter for block comments within SQL scripts: "*/". + */ + public static final String DEFAULT_BLOCK_COMMENT_END_DELIMITER = "*/"; + + + private static final Log logger = LogFactory.getLog(ScriptUtils.class); + + // utility constructor + private ScriptUtils() {} + + /** + * Split an SQL script into separate statements delimited by the provided + * separator character. Each individual statement will be added to the + * provided {@code List}. + *

Within the script, {@value #DEFAULT_COMMENT_PREFIX} will be used as the + * comment prefix; any text beginning with the comment prefix and extending to + * the end of the line will be omitted from the output. Similarly, + * {@value #DEFAULT_BLOCK_COMMENT_START_DELIMITER} and + * {@value #DEFAULT_BLOCK_COMMENT_END_DELIMITER} will be used as the + * start and end block comment delimiters: any text enclosed + * in a block comment will be omitted from the output. In addition, multiple + * adjacent whitespace characters will be collapsed into a single space. + * @param script the SQL script + * @param separator character separating each statement (typically a ';') + * @param statements the list that will contain the individual statements + * @throws ScriptException if an error occurred while splitting the SQL script + * @see #splitSqlScript(String, String, List) + * @see #splitSqlScript(EncodedResource, String, String, String, String, String, List) + */ + public static void splitSqlScript(String script, char separator, List statements) throws ScriptException { + splitSqlScript(script, String.valueOf(separator), statements); + } + + /** + * Split an SQL script into separate statements delimited by the provided + * separator string. Each individual statement will be added to the + * provided {@code List}. + *

Within the script, {@value #DEFAULT_COMMENT_PREFIX} will be used as the + * comment prefix; any text beginning with the comment prefix and extending to + * the end of the line will be omitted from the output. Similarly, + * {@value #DEFAULT_BLOCK_COMMENT_START_DELIMITER} and + * {@value #DEFAULT_BLOCK_COMMENT_END_DELIMITER} will be used as the + * start and end block comment delimiters: any text enclosed + * in a block comment will be omitted from the output. In addition, multiple + * adjacent whitespace characters will be collapsed into a single space. + * @param script the SQL script + * @param separator text separating each statement + * (typically a ';' or newline character) + * @param statements the list that will contain the individual statements + * @throws ScriptException if an error occurred while splitting the SQL script + * @see #splitSqlScript(String, char, List) + * @see #splitSqlScript(EncodedResource, String, String, String, String, String, List) + */ + public static void splitSqlScript(String script, String separator, List statements) throws ScriptException { + splitSqlScript(null, script, separator, DEFAULT_COMMENT_PREFIX, DEFAULT_BLOCK_COMMENT_START_DELIMITER, + DEFAULT_BLOCK_COMMENT_END_DELIMITER, statements); + } + + /** + * Split an SQL script into separate statements delimited by the provided + * separator string. Each individual statement will be added to the provided + * {@code List}. + *

Within the script, the provided {@code commentPrefix} will be honored: + * any text beginning with the comment prefix and extending to the end of the + * line will be omitted from the output. Similarly, the provided + * {@code blockCommentStartDelimiter} and {@code blockCommentEndDelimiter} + * delimiters will be honored: any text enclosed in a block comment will be + * omitted from the output. In addition, multiple adjacent whitespace characters + * will be collapsed into a single space. + * @param resource the resource from which the script was read + * @param script the SQL script + * @param separator text separating each statement + * (typically a ';' or newline character) + * @param commentPrefix the prefix that identifies SQL line comments + * (typically "--") + * @param blockCommentStartDelimiter the start block comment delimiter; + * never {@code null} or empty + * @param blockCommentEndDelimiter the end block comment delimiter; + * never {@code null} or empty + * @param statements the list that will contain the individual statements + * @throws ScriptException if an error occurred while splitting the SQL script + */ + public static void splitSqlScript(@Nullable EncodedResource resource, String script, + String separator, String commentPrefix, String blockCommentStartDelimiter, + String blockCommentEndDelimiter, List statements) throws ScriptException { + + Assert.hasText(commentPrefix, "'commentPrefix' must not be null or empty"); + splitSqlScript(resource, script, separator, new String[] { commentPrefix }, + blockCommentStartDelimiter, blockCommentEndDelimiter, statements); + } + + /** + * Split an SQL script into separate statements delimited by the provided + * separator string. Each individual statement will be added to the provided + * {@code List}. + *

Within the script, the provided {@code commentPrefixes} will be honored: + * any text beginning with one of the comment prefixes and extending to the + * end of the line will be omitted from the output. Similarly, the provided + * {@code blockCommentStartDelimiter} and {@code blockCommentEndDelimiter} + * delimiters will be honored: any text enclosed in a block comment will be + * omitted from the output. In addition, multiple adjacent whitespace characters + * will be collapsed into a single space. + * @param resource the resource from which the script was read + * @param script the SQL script + * @param separator text separating each statement + * (typically a ';' or newline character) + * @param commentPrefixes the prefixes that identify SQL line comments + * (typically "--") + * @param blockCommentStartDelimiter the start block comment delimiter; + * never {@code null} or empty + * @param blockCommentEndDelimiter the end block comment delimiter; + * never {@code null} or empty + * @param statements the list that will contain the individual statements + * @throws ScriptException if an error occurred while splitting the SQL script + */ + public static void splitSqlScript(@Nullable EncodedResource resource, String script, + String separator, String[] commentPrefixes, String blockCommentStartDelimiter, + String blockCommentEndDelimiter, List statements) throws ScriptException { + + Assert.hasText(script, "'script' must not be null or empty"); + Assert.notNull(separator, "'separator' must not be null"); + Assert.notEmpty(commentPrefixes, "'commentPrefixes' must not be null or empty"); + for (String commentPrefix : commentPrefixes) { + Assert.hasText(commentPrefix, "'commentPrefixes' must not contain null or empty elements"); + } + Assert.hasText(blockCommentStartDelimiter, "'blockCommentStartDelimiter' must not be null or empty"); + Assert.hasText(blockCommentEndDelimiter, "'blockCommentEndDelimiter' must not be null or empty"); + + StringBuilder sb = new StringBuilder(); + boolean inSingleQuote = false; + boolean inDoubleQuote = false; + boolean inEscape = false; + + for (int i = 0; i < script.length(); i++) { + char c = script.charAt(i); + if (inEscape) { + inEscape = false; + sb.append(c); + continue; + } + // MySQL style escapes + if (c == '\\') { + inEscape = true; + sb.append(c); + continue; + } + if (!inDoubleQuote && (c == '\'')) { + inSingleQuote = !inSingleQuote; + } + else if (!inSingleQuote && (c == '"')) { + inDoubleQuote = !inDoubleQuote; + } + if (!inSingleQuote && !inDoubleQuote) { + if (script.startsWith(separator, i)) { + // We've reached the end of the current statement + if (sb.length() > 0) { + statements.add(sb.toString()); + sb = new StringBuilder(); + } + i += separator.length() - 1; + continue; + } + else if (startsWithAny(script, commentPrefixes, i)) { + // Skip over any content from the start of the comment to the EOL + int indexOfNextNewline = script.indexOf('\n', i); + if (indexOfNextNewline > i) { + i = indexOfNextNewline; + continue; + } + else { + // If there's no EOL, we must be at the end of the script, so stop here. + break; + } + } + else if (script.startsWith(blockCommentStartDelimiter, i)) { + // Skip over any block comments + int indexOfCommentEnd = script.indexOf(blockCommentEndDelimiter, i); + if (indexOfCommentEnd > i) { + i = indexOfCommentEnd + blockCommentEndDelimiter.length() - 1; + continue; + } + else { + throw new ScriptParseException( + "Missing block comment end delimiter: " + blockCommentEndDelimiter, resource); + } + } + else if (c == ' ' || c == '\r' || c == '\n' || c == '\t') { + // Avoid multiple adjacent whitespace characters + if (sb.length() > 0 && sb.charAt(sb.length() - 1) != ' ') { + c = ' '; + } + else { + continue; + } + } + } + sb.append(c); + } + + if (StringUtils.hasText(sb)) { + statements.add(sb.toString()); + } + } + + /** + * Read a script from the given resource, using "{@code --}" as the comment prefix + * and "{@code ;}" as the statement separator, and build a String containing the lines. + * @param resource the {@code EncodedResource} to be read + * @return {@code String} containing the script lines + */ + public static Mono readScript(EncodedResource resource, DataBufferFactory dataBufferFactory) { + return readScript(resource, dataBufferFactory, DEFAULT_COMMENT_PREFIXES, DEFAULT_STATEMENT_SEPARATOR, + DEFAULT_BLOCK_COMMENT_END_DELIMITER); + } + + /** + * Read a script from the provided resource, using the supplied comment prefixes + * and statement separator, and build a {@code String} containing the lines. + *

Lines beginning with one of the comment prefixes are excluded + * from the results; however, line comments anywhere else — for example, + * within a statement — will be included in the results. + * @param resource the {@code EncodedResource} containing the script + * to be processed + * @param commentPrefixes the prefixes that identify comments in the SQL script + * (typically "--") + * @param separator the statement separator in the SQL script (typically ";") + * @param blockCommentEndDelimiter the end block comment delimiter + * @return a {@link Mono} of {@link String} containing the script lines that + * completes once the resource was loaded + */ + private static Mono readScript(EncodedResource resource, DataBufferFactory dataBufferFactory, + @Nullable String[] commentPrefixes, @Nullable String separator, @Nullable String blockCommentEndDelimiter) { + + return DataBufferUtils.join(DataBufferUtils.read(resource.getResource(), dataBufferFactory, 8192)) + .handle((it, sink) -> { + + try (InputStream is = it.asInputStream()) { + + InputStreamReader in = resource.getCharset() != null ? new InputStreamReader(is, resource.getCharset()) + : new InputStreamReader(is); + LineNumberReader lnr = new LineNumberReader(in); + String script = readScript(lnr, commentPrefixes, separator, blockCommentEndDelimiter); + + sink.next(script); + sink.complete(); + } + catch (Exception ex) { + sink.error(ex); + } + finally { + DataBufferUtils.release(it); + } + }); + } + + /** + * Read a script from the provided {@code LineNumberReader}, using the supplied + * comment prefix and statement separator, and build a {@code String} containing + * the lines. + *

Lines beginning with the comment prefix are excluded from the + * results; however, line comments anywhere else — for example, within + * a statement — will be included in the results. + * @param lineNumberReader the {@code LineNumberReader} containing the script + * to be processed + * @param lineCommentPrefix the prefix that identifies comments in the SQL script + * (typically "--") + * @param separator the statement separator in the SQL script (typically ";") + * @param blockCommentEndDelimiter the end block comment delimiter + * @return a {@code String} containing the script lines + * @throws IOException in case of I/O errors + */ + public static String readScript(LineNumberReader lineNumberReader, @Nullable String lineCommentPrefix, + @Nullable String separator, @Nullable String blockCommentEndDelimiter) throws IOException { + String[] lineCommentPrefixes = (lineCommentPrefix != null) ? new String[] { lineCommentPrefix } : null; + return readScript(lineNumberReader, lineCommentPrefixes, separator, blockCommentEndDelimiter); + } + + /** + * Read a script from the provided {@code LineNumberReader}, using the supplied + * comment prefixes and statement separator, and build a {@code String} containing + * the lines. + *

Lines beginning with one of the comment prefixes are excluded + * from the results; however, line comments anywhere else — for example, + * within a statement — will be included in the results. + * @param lineNumberReader the {@code LineNumberReader} containing the script + * to be processed + * @param lineCommentPrefixes the prefixes that identify comments in the SQL script + * (typically "--") + * @param separator the statement separator in the SQL script (typically ";") + * @param blockCommentEndDelimiter the end block comment delimiter + * @return a {@code String} containing the script lines + * @throws IOException in case of I/O errors + */ + public static String readScript(LineNumberReader lineNumberReader, @Nullable String[] lineCommentPrefixes, + @Nullable String separator, @Nullable String blockCommentEndDelimiter) throws IOException { + + String currentStatement = lineNumberReader.readLine(); + StringBuilder scriptBuilder = new StringBuilder(); + while (currentStatement != null) { + if ((blockCommentEndDelimiter != null && currentStatement.contains(blockCommentEndDelimiter)) || + (lineCommentPrefixes != null && !startsWithAny(currentStatement, lineCommentPrefixes, 0))) { + if (scriptBuilder.length() > 0) { + scriptBuilder.append('\n'); + } + scriptBuilder.append(currentStatement); + } + currentStatement = lineNumberReader.readLine(); + } + appendSeparatorToScriptIfNecessary(scriptBuilder, separator); + return scriptBuilder.toString(); + } + + private static void appendSeparatorToScriptIfNecessary(StringBuilder scriptBuilder, @Nullable String separator) { + if (separator == null) { + return; + } + String trimmed = separator.trim(); + if (trimmed.length() == separator.length()) { + return; + } + // separator ends in whitespace, so we might want to see if the script is trying + // to end the same way + if (scriptBuilder.lastIndexOf(trimmed) == scriptBuilder.length() - trimmed.length()) { + scriptBuilder.append(separator.substring(trimmed.length())); + } + } + + private static boolean startsWithAny(String script, String[] prefixes, int offset) { + for (String prefix : prefixes) { + if (script.startsWith(prefix, offset)) { + return true; + } + } + return false; + } + + /** + * Does the provided SQL script contain the specified delimiter? + * @param script the SQL script + * @param delim the string delimiting each statement - typically a ';' character + */ + public static boolean containsSqlScriptDelimiters(String script, String delim) { + boolean inLiteral = false; + boolean inEscape = false; + + for (int i = 0; i < script.length(); i++) { + char c = script.charAt(i); + if (inEscape) { + inEscape = false; + continue; + } + // MySQL style escapes + if (c == '\\') { + inEscape = true; + continue; + } + if (c == '\'') { + inLiteral = !inLiteral; + } + if (!inLiteral && script.startsWith(delim, i)) { + return true; + } + } + + return false; + } + + /** + * Execute the given SQL script using default settings for statement + * separators, comment delimiters, and exception handling flags. + *

Statement separators and comments will be removed before executing + * individual statements within the supplied script. + *

Warning: this method does not release the + * provided {@link Connection}. + * @param connection the R2DBC connection to use to execute the script; already + * configured and ready to use + * @param resource the resource to load the SQL script from; encoded with the + * current platform's default encoding + * @throws ScriptException if an error occurred while executing the SQL script + * @see #executeSqlScript(Connection, EncodedResource, DataBufferFactory, boolean, boolean, String[], String, String, String) + * @see #DEFAULT_STATEMENT_SEPARATOR + * @see #DEFAULT_COMMENT_PREFIX + * @see #DEFAULT_BLOCK_COMMENT_START_DELIMITER + * @see #DEFAULT_BLOCK_COMMENT_END_DELIMITER + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection + */ + public static Mono executeSqlScript(Connection connection, Resource resource) throws ScriptException { + return executeSqlScript(connection, new EncodedResource(resource)); + } + + /** + * Execute the given SQL script using default settings for statement + * separators, comment delimiters, and exception handling flags. + *

Statement separators and comments will be removed before executing + * individual statements within the supplied script. + *

Warning: this method does not release the + * provided {@link Connection}. + * @param connection the R2DBC connection to use to execute the script; already + * configured and ready to use + * @param resource the resource (potentially associated with a specific encoding) + * to load the SQL script from + * @throws ScriptException if an error occurred while executing the SQL script + * @see #executeSqlScript(Connection, EncodedResource, DataBufferFactory, boolean, boolean, String[], String, String, String) + * @see #DEFAULT_STATEMENT_SEPARATOR + * @see #DEFAULT_COMMENT_PREFIX + * @see #DEFAULT_BLOCK_COMMENT_START_DELIMITER + * @see #DEFAULT_BLOCK_COMMENT_END_DELIMITER + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection + */ + public static Mono executeSqlScript(Connection connection, EncodedResource resource) throws ScriptException { + return executeSqlScript(connection, resource, new DefaultDataBufferFactory(), false, false, DEFAULT_COMMENT_PREFIX, + DEFAULT_STATEMENT_SEPARATOR, DEFAULT_BLOCK_COMMENT_START_DELIMITER, DEFAULT_BLOCK_COMMENT_END_DELIMITER); + } + + /** + * Execute the given SQL script. + *

Statement separators and comments will be removed before executing + * individual statements within the supplied script. + *

Warning: this method does not release the + * provided {@link Connection}. + * @param connection the R2DBC connection to use to execute the script; already + * configured and ready to use + * @param resource the resource (potentially associated with a specific encoding) + * to load the SQL script from + * @param continueOnError whether or not to continue without throwing an exception + * in the event of an error + * @param ignoreFailedDrops whether or not to continue in the event of specifically + * an error on a {@code DROP} statement + * @param commentPrefix the prefix that identifies single-line comments in the + * SQL script (typically "--") + * @param separator the script statement separator; defaults to + * {@value #DEFAULT_STATEMENT_SEPARATOR} if not specified and falls back to + * {@value #FALLBACK_STATEMENT_SEPARATOR} as a last resort; may be set to + * {@value #EOF_STATEMENT_SEPARATOR} to signal that the script contains a + * single statement without a separator + * @param blockCommentStartDelimiter the start block comment delimiter + * @param blockCommentEndDelimiter the end block comment delimiter + * @throws ScriptException if an error occurred while executing the SQL script + * @see #DEFAULT_STATEMENT_SEPARATOR + * @see #FALLBACK_STATEMENT_SEPARATOR + * @see #EOF_STATEMENT_SEPARATOR + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection + */ + public static Mono executeSqlScript(Connection connection, EncodedResource resource, + DataBufferFactory dataBufferFactory, boolean continueOnError, boolean ignoreFailedDrops, String commentPrefix, + @Nullable String separator, String blockCommentStartDelimiter, String blockCommentEndDelimiter) + throws ScriptException { + + return executeSqlScript(connection, resource, dataBufferFactory, continueOnError, + ignoreFailedDrops, new String[] { commentPrefix }, separator, + blockCommentStartDelimiter, blockCommentEndDelimiter); + } + + /** + * Execute the given SQL script. + *

Statement separators and comments will be removed before executing + * individual statements within the supplied script. + *

Warning: this method does not release the + * provided {@link Connection}. + * @param connection the R2DBC connection to use to execute the script; already + * configured and ready to use + * @param resource the resource (potentially associated with a specific encoding) + * to load the SQL script from + * @param continueOnError whether or not to continue without throwing an exception + * in the event of an error + * @param ignoreFailedDrops whether or not to continue in the event of specifically + * an error on a {@code DROP} statement + * @param commentPrefixes the prefixes that identify single-line comments in the + * SQL script (typically "--") + * @param separator the script statement separator; defaults to + * {@value #DEFAULT_STATEMENT_SEPARATOR} if not specified and falls back to + * {@value #FALLBACK_STATEMENT_SEPARATOR} as a last resort; may be set to + * {@value #EOF_STATEMENT_SEPARATOR} to signal that the script contains a + * single statement without a separator + * @param blockCommentStartDelimiter the start block comment delimiter + * @param blockCommentEndDelimiter the end block comment delimiter + * @throws ScriptException if an error occurred while executing the SQL script + * @see #DEFAULT_STATEMENT_SEPARATOR + * @see #FALLBACK_STATEMENT_SEPARATOR + * @see #EOF_STATEMENT_SEPARATOR + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#getConnection + * @see org.springframework.r2dbc.connection.ConnectionFactoryUtils#releaseConnection + */ + public static Mono executeSqlScript(Connection connection, EncodedResource resource, DataBufferFactory dataBufferFactory, + boolean continueOnError, + boolean ignoreFailedDrops, String[] commentPrefixes, @Nullable String separator, + String blockCommentStartDelimiter, String blockCommentEndDelimiter) throws ScriptException { + + if (logger.isDebugEnabled()) { + logger.debug("Executing SQL script from " + resource); + } + + long startTime = System.currentTimeMillis(); + + Mono script = readScript(resource, dataBufferFactory, commentPrefixes, separator, blockCommentEndDelimiter) + .onErrorMap(IOException.class, ex -> new CannotReadScriptException(resource, ex)); + + AtomicInteger statementNumber = new AtomicInteger(); + + Flux executeScript = script.flatMapIterable(statement -> { + List statements = new ArrayList<>(); + String separatorToUse = separator; + if (separatorToUse == null) { + separatorToUse = DEFAULT_STATEMENT_SEPARATOR; + } + if (!EOF_STATEMENT_SEPARATOR.equals(separatorToUse) && !containsSqlScriptDelimiters(statement, separatorToUse)) { + separatorToUse = FALLBACK_STATEMENT_SEPARATOR; + } + splitSqlScript(resource, statement, separatorToUse, commentPrefixes, blockCommentStartDelimiter, + blockCommentEndDelimiter, statements); + return statements; + }).concatMap(statement -> { + + statementNumber.incrementAndGet(); + return runStatement(statement, connection, resource, continueOnError, ignoreFailedDrops, statementNumber); + }); + + if (logger.isDebugEnabled()) { + + executeScript = executeScript.doOnComplete(() -> { + + long elapsedTime = System.currentTimeMillis() - startTime; + logger.debug("Executed SQL script from " + resource + " in " + elapsedTime + " ms."); + }); + } + + return executeScript.onErrorMap(ex -> !(ex instanceof ScriptException), + ex -> new UncategorizedScriptException("Failed to execute database script from resource [" + resource + "]", + ex)) + .then(); + } + + private static Publisher runStatement(String statement, Connection connection, + EncodedResource resource, boolean continueOnError, boolean ignoreFailedDrops, AtomicInteger statementNumber) { + + Mono execution = Flux.from(connection.createStatement(statement).execute()) + .flatMap(Result::getRowsUpdated) + .collect(Collectors.summingLong(count -> count)); + + if (logger.isDebugEnabled()) { + execution = execution.doOnNext(rowsAffected -> logger.debug(rowsAffected + " returned as update count for SQL: " + statement)); + } + + return execution.onErrorResume(ex -> { + + boolean dropStatement = StringUtils.startsWithIgnoreCase(statement.trim(), "drop"); + if (continueOnError || (dropStatement && ignoreFailedDrops)) { + if (logger.isDebugEnabled()) { + logger.debug(ScriptStatementFailedException.buildErrorMessage(statement, statementNumber.get(), resource), + ex); + } + } + else { + return Mono.error(new ScriptStatementFailedException(statement, statementNumber.get(), resource, ex)); + } + + return Mono.empty(); + }).then(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/UncategorizedScriptException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/UncategorizedScriptException.java new file mode 100644 index 000000000000..bd0eb04d9da3 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/UncategorizedScriptException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +/** + * Thrown when we cannot determine anything more specific than "something went wrong while + * processing an SQL script": for example, a {@link io.r2dbc.spi.R2dbcException} from + * R2DBC that we cannot pinpoint more precisely. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class UncategorizedScriptException extends ScriptException { + + /** + * Create a new {@code UncategorizedScriptException}. + * @param message detailed message + */ + public UncategorizedScriptException(String message) { + super(message); + } + + /** + * Create a new {@code UncategorizedScriptException}. + * @param message detailed message + * @param cause the root cause + */ + public UncategorizedScriptException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/package-info.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/package-info.java new file mode 100644 index 000000000000..3153ea5232c3 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/init/package-info.java @@ -0,0 +1,6 @@ +/** + * Provides extensible support for initializing databases through scripts. + */ +@org.springframework.lang.NonNullApi +@org.springframework.lang.NonNullFields +package org.springframework.r2dbc.connection.init; diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/AbstractRoutingConnectionFactory.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/AbstractRoutingConnectionFactory.java new file mode 100644 index 000000000000..2649e179b975 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/AbstractRoutingConnectionFactory.java @@ -0,0 +1,243 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import java.util.HashMap; +import java.util.Map; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import reactor.core.publisher.Mono; + +import org.springframework.beans.factory.InitializingBean; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Abstract {@link ConnectionFactory} implementation that routes + * {@link #create()} calls to one of various target + * {@link ConnectionFactory factories} based on a lookup key. + * The latter is typically (but not necessarily) determined from some + * subscriber context. + * + *

Allows to configure a {@link #setDefaultTargetConnectionFactory(Object) + * default ConnectionFactory} as fallback. + * + *

Calls to {@link #getMetadata()} are routed to the + * {@link #setDefaultTargetConnectionFactory(Object) default ConnectionFactory} + * if configured. + * + * @author Mark Paluch + * @author Jens Schauder + * @since 5.3 + * @see #setTargetConnectionFactories + * @see #setDefaultTargetConnectionFactory + * @see #determineCurrentLookupKey() + */ +public abstract class AbstractRoutingConnectionFactory implements ConnectionFactory, InitializingBean { + + private static final Object FALLBACK_MARKER = new Object(); + + + @Nullable + private Map targetConnectionFactories; + + @Nullable + private Object defaultTargetConnectionFactory; + + private boolean lenientFallback = true; + + private ConnectionFactoryLookup connectionFactoryLookup = new MapConnectionFactoryLookup(); + + @Nullable + private Map resolvedConnectionFactories; + + @Nullable + private ConnectionFactory resolvedDefaultConnectionFactory; + + + /** + * Specify the map of target {@link ConnectionFactory ConnectionFactories}, + * with the lookup key as key. The mapped value can either be a corresponding + * {@link ConnectionFactory} instance or a connection factory name String (to be + * resolved via a {@link #setConnectionFactoryLookup ConnectionFactoryLookup}). + * + *

The key can be of arbitrary type; this class implements the generic lookup + * process only. The concrete key representation will be handled by + * {@link #resolveSpecifiedLookupKey(Object)} and {@link #determineCurrentLookupKey()}. + */ + public void setTargetConnectionFactories(Map targetConnectionFactories) { + this.targetConnectionFactories = targetConnectionFactories; + } + + /** + * Specify the default target {@link ConnectionFactory}, if any. + * + *

The mapped value can either be a corresponding {@link ConnectionFactory} + * instance or a connection factory name {@link String} (to be resolved via a + * {@link #setConnectionFactoryLookup ConnectionFactoryLookup}). + * + *

This {@link ConnectionFactory} will be used as target if none of the keyed + * {@link #setTargetConnectionFactories targetConnectionFactories} match the + * {@link #determineCurrentLookupKey() current lookup key}. + */ + public void setDefaultTargetConnectionFactory(Object defaultTargetConnectionFactory) { + this.defaultTargetConnectionFactory = defaultTargetConnectionFactory; + } + + /** + * Specify whether to apply a lenient fallback to the default {@link ConnectionFactory} + * if no specific {@link ConnectionFactory} could be found for the current lookup key. + * + *

Default is {@code true}, accepting lookup keys without a corresponding entry + * in the target {@link ConnectionFactory} map - simply falling back to the default + * {@link ConnectionFactory} in that case. + * + *

Switch this flag to {@code false} if you would prefer the fallback to only + * apply when no lookup key was emitted. Lookup keys without a {@link ConnectionFactory} + * entry will then lead to an {@link IllegalStateException}. + * @see #setTargetConnectionFactories + * @see #setDefaultTargetConnectionFactory + * @see #determineCurrentLookupKey() + */ + public void setLenientFallback(boolean lenientFallback) { + this.lenientFallback = lenientFallback; + } + + /** + * Set the {@link ConnectionFactoryLookup} implementation to use for resolving + * connection factory name Strings in the {@link #setTargetConnectionFactories + * targetConnectionFactories} map. + */ + public void setConnectionFactoryLookup(ConnectionFactoryLookup connectionFactoryLookup) { + Assert.notNull(connectionFactoryLookup, "ConnectionFactoryLookup must not be null"); + this.connectionFactoryLookup = connectionFactoryLookup; + } + + @Override + public void afterPropertiesSet() { + + Assert.notNull(this.targetConnectionFactories, "Property 'targetConnectionFactories' must not be null"); + + this.resolvedConnectionFactories = new HashMap<>(this.targetConnectionFactories.size()); + this.targetConnectionFactories.forEach((key, value) -> { + Object lookupKey = resolveSpecifiedLookupKey(key); + ConnectionFactory connectionFactory = resolveSpecifiedConnectionFactory(value); + this.resolvedConnectionFactories.put(lookupKey, connectionFactory); + }); + + if (this.defaultTargetConnectionFactory != null) { + this.resolvedDefaultConnectionFactory = resolveSpecifiedConnectionFactory(this.defaultTargetConnectionFactory); + } + } + + /** + * Resolve the given lookup key object, as specified in the + * {@link #setTargetConnectionFactories targetConnectionFactories} map, + * into the actual lookup key to be used for matching with the + * {@link #determineCurrentLookupKey() current lookup key}. + *

The default implementation simply returns the given key as-is. + * @param lookupKey the lookup key object as specified by the user + * @return the lookup key as needed for matching. + */ + protected Object resolveSpecifiedLookupKey(Object lookupKey) { + return lookupKey; + } + + /** + * Resolve the specified connection factory object into a + * {@link ConnectionFactory} instance. + *

The default implementation handles {@link ConnectionFactory} instances + * and connection factory names (to be resolved via a + * {@link #setConnectionFactoryLookup ConnectionFactoryLookup}). + * @param connectionFactory the connection factory value object as specified in the + * {@link #setTargetConnectionFactories targetConnectionFactories} map + * @return the resolved {@link ConnectionFactory} (never {@code null}) + * @throws IllegalArgumentException in case of an unsupported value type + */ + protected ConnectionFactory resolveSpecifiedConnectionFactory(Object connectionFactory) + throws IllegalArgumentException { + if (connectionFactory instanceof ConnectionFactory) { + return (ConnectionFactory) connectionFactory; + } + else if (connectionFactory instanceof String) { + return this.connectionFactoryLookup.getConnectionFactory((String) connectionFactory); + } + else { + throw new IllegalArgumentException( + "Illegal connection factory value - only 'io.r2dbc.spi.ConnectionFactory' and 'String' supported: " + + connectionFactory); + } + } + + @Override + public Mono create() { + return determineTargetConnectionFactory() // + .map(ConnectionFactory::create) // + .flatMap(Mono::from); + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + if (this.resolvedDefaultConnectionFactory != null) { + return this.resolvedDefaultConnectionFactory.getMetadata(); + } + throw new UnsupportedOperationException( + "No default ConnectionFactory configured to retrieve ConnectionFactoryMetadata"); + } + + /** + * Retrieve the current target {@link ConnectionFactory}. Determines the + * {@link #determineCurrentLookupKey() current lookup key}, performs a lookup + * in the {@link #setTargetConnectionFactories targetConnectionFactories} map, + * falls back to the specified {@link #setDefaultTargetConnectionFactory default + * target ConnectionFactory} if necessary. + * @return {@link Mono} emitting the current {@link ConnectionFactory} as + * per {@link #determineCurrentLookupKey()} + * @see #determineCurrentLookupKey() + */ + protected Mono determineTargetConnectionFactory() { + Assert.state(this.resolvedConnectionFactories != null, "ConnectionFactory router not initialized"); + + Mono lookupKey = determineCurrentLookupKey().defaultIfEmpty(FALLBACK_MARKER); + + return lookupKey.handle((key, sink) -> { + ConnectionFactory connectionFactory = this.resolvedConnectionFactories.get(key); + if (connectionFactory == null && (key == FALLBACK_MARKER || this.lenientFallback)) { + connectionFactory = this.resolvedDefaultConnectionFactory; + } + if (connectionFactory == null) { + sink.error(new IllegalStateException(String.format( + "Cannot determine target ConnectionFactory for lookup key '%s'", key == FALLBACK_MARKER ? null : key))); + return; + } + sink.next(connectionFactory); + }); + } + + /** + * Determine the current lookup key. This will typically be implemented to check a + * subscriber context. Allows for arbitrary keys. The returned key needs to match the + * stored lookup key type, as resolved by the {@link #resolveSpecifiedLookupKey} method. + * + * @return {@link Mono} emitting the lookup key. May complete without emitting a value + * if no lookup key available + */ + protected abstract Mono determineCurrentLookupKey(); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/BeanFactoryConnectionFactoryLookup.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/BeanFactoryConnectionFactoryLookup.java new file mode 100644 index 000000000000..1e19d3eb0b92 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/BeanFactoryConnectionFactoryLookup.java @@ -0,0 +1,85 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link ConnectionFactoryLookup} implementation based on a + * Spring {@link BeanFactory}. + * + *

Will lookup Spring managed beans identified by bean name, + * expecting them to be of type {@link ConnectionFactory}. + * + * @author Mark Paluch + * @since 5.3 + * @see BeanFactory + */ +public class BeanFactoryConnectionFactoryLookup implements ConnectionFactoryLookup, BeanFactoryAware { + + @Nullable + private BeanFactory beanFactory; + + + /** + * Create a new instance of the {@link BeanFactoryConnectionFactoryLookup} class. + *

The BeanFactory to access must be set via {@code setBeanFactory}. + * @see #setBeanFactory + */ + public BeanFactoryConnectionFactoryLookup() {} + + /** + * Create a new instance of the {@link BeanFactoryConnectionFactoryLookup} class. + *

Use of this constructor is redundant if this object is being created + * by a Spring IoC container, as the supplied {@link BeanFactory} will be + * replaced by the {@link BeanFactory} that creates it (c.f. the + * {@link BeanFactoryAware} contract). So only use this constructor if you + * are using this class outside the context of a Spring IoC container. + * @param beanFactory the bean factory to be used to lookup {@link ConnectionFactory + * ConnectionFactories} + */ + public BeanFactoryConnectionFactoryLookup(BeanFactory beanFactory) { + Assert.notNull(beanFactory, "BeanFactory must not be null"); + this.beanFactory = beanFactory; + } + + + @Override + public void setBeanFactory(BeanFactory beanFactory) { + this.beanFactory = beanFactory; + } + + @Override + public ConnectionFactory getConnectionFactory(String connectionFactoryName) + throws ConnectionFactoryLookupFailureException { + Assert.state(this.beanFactory != null, "BeanFactory is required"); + try { + return this.beanFactory.getBean(connectionFactoryName, ConnectionFactory.class); + } + catch (BeansException ex) { + throw new ConnectionFactoryLookupFailureException( + String.format("Failed to look up ConnectionFactory bean with name '%s'", connectionFactoryName), ex); + } + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/ConnectionFactoryLookup.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/ConnectionFactoryLookup.java new file mode 100644 index 000000000000..70d5c99f1204 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/ConnectionFactoryLookup.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import io.r2dbc.spi.ConnectionFactory; + +/** + * Strategy interface for looking up {@link ConnectionFactory} by name. + * + * @author Mark Paluch + * @since 5.3 + */ +@FunctionalInterface +public interface ConnectionFactoryLookup { + + /** + * Retrieve the {@link ConnectionFactory} identified by the given name. + * @param connectionFactoryName the name of the {@link ConnectionFactory} + * @return the {@link ConnectionFactory} (never {@code null}) + * @throws ConnectionFactoryLookupFailureException if the lookup failed + */ + ConnectionFactory getConnectionFactory(String connectionFactoryName) throws ConnectionFactoryLookupFailureException; + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/ConnectionFactoryLookupFailureException.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/ConnectionFactoryLookupFailureException.java new file mode 100644 index 000000000000..718525a5ee25 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/ConnectionFactoryLookupFailureException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import org.springframework.dao.NonTransientDataAccessException; + +/** + * Exception to be thrown by a {@link ConnectionFactoryLookup} implementation, + * indicating that the specified {@link io.r2dbc.spi.ConnectionFactory} could + * not be obtained. + * + * @author Mark Paluch + * @since 5.3 + */ +@SuppressWarnings("serial") +public class ConnectionFactoryLookupFailureException extends NonTransientDataAccessException { + + /** + * Create a new {@code ConnectionFactoryLookupFailureException}. + * @param msg the detail message + */ + public ConnectionFactoryLookupFailureException(String msg) { + super(msg); + } + + /** + * Create a new {@code ConnectionFactoryLookupFailureException}. + * @param msg the detail message + * @param cause the root cause + */ + public ConnectionFactoryLookupFailureException(String msg, Throwable cause) { + super(msg, cause); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/MapConnectionFactoryLookup.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/MapConnectionFactoryLookup.java new file mode 100644 index 000000000000..0e6f0f12f579 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/MapConnectionFactoryLookup.java @@ -0,0 +1,111 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.util.Assert; + +/** + * Simple {@link ConnectionFactoryLookup} implementation that relies + * on a map for doing lookups. + * + *

Useful for testing environments or applications that need to match + * arbitrary {@link String} names to target {@link ConnectionFactory} objects. + * + * @author Mark Paluch + * @author Jens Schauder + * @since 5.3 + */ +public class MapConnectionFactoryLookup implements ConnectionFactoryLookup { + + private final Map connectionFactories = new HashMap<>(); + + + /** + * Create a new instance of the {@link MapConnectionFactoryLookup} class. + */ + public MapConnectionFactoryLookup() {} + + /** + * Create a new instance of the {@link MapConnectionFactoryLookup} class. + * @param connectionFactories the {@link Map} of {@link ConnectionFactory}. + * The keys are {@link String Strings}, the values are actual {@link ConnectionFactory} instances. + */ + public MapConnectionFactoryLookup(Map connectionFactories) { + setConnectionFactories(connectionFactories); + } + + /** + * Create a new instance of the {@link MapConnectionFactoryLookup} class. + * + * @param connectionFactoryName the name under which the supplied {@link ConnectionFactory} is to be added + * @param connectionFactory the {@link ConnectionFactory} to be added + */ + public MapConnectionFactoryLookup(String connectionFactoryName, ConnectionFactory connectionFactory) { + addConnectionFactory(connectionFactoryName, connectionFactory); + } + + + /** + * Set the {@link Map} of {@link ConnectionFactory ConnectionFactories}. + * The keys are {@link String Strings}, the values are actual {@link ConnectionFactory} instances. + *

If the supplied {@link Map} is {@code null}, then this method call effectively has no effect. + * @param connectionFactories said {@link Map} of {@link ConnectionFactory connectionFactories} + */ + public void setConnectionFactories(Map connectionFactories) { + Assert.notNull(connectionFactories, "ConnectionFactories must not be null"); + this.connectionFactories.putAll(connectionFactories); + } + + /** + * Get the {@link Map} of {@link ConnectionFactory ConnectionFactories} maintained by this object. + *

The returned {@link Map} is {@link Collections#unmodifiableMap(Map) unmodifiable}. + * @return {@link Map} of {@link ConnectionFactory connectionFactory} (never {@code null}) + */ + public Map getConnectionFactories() { + return Collections.unmodifiableMap(this.connectionFactories); + } + + /** + * Add the supplied {@link ConnectionFactory} to the map of {@link ConnectionFactory ConnectionFactorys} maintained by + * this object. + * + * @param connectionFactoryName the name under which the supplied {@link ConnectionFactory} is to be added + * @param connectionFactory the {@link ConnectionFactory} to be so added + */ + public void addConnectionFactory(String connectionFactoryName, ConnectionFactory connectionFactory) { + Assert.notNull(connectionFactoryName, "ConnectionFactory name must not be null"); + Assert.notNull(connectionFactory, "ConnectionFactory must not be null"); + this.connectionFactories.put(connectionFactoryName, connectionFactory); + } + + @Override + public ConnectionFactory getConnectionFactory(String connectionFactoryName) + throws ConnectionFactoryLookupFailureException { + Assert.notNull(connectionFactoryName, "ConnectionFactory name must not be null"); + return this.connectionFactories.computeIfAbsent(connectionFactoryName, key -> { + throw new ConnectionFactoryLookupFailureException( + "No ConnectionFactory with name '" + connectionFactoryName + "' registered"); + }); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/SingleConnectionFactoryLookup.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/SingleConnectionFactoryLookup.java new file mode 100644 index 000000000000..afda412ce825 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/SingleConnectionFactoryLookup.java @@ -0,0 +1,52 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import io.r2dbc.spi.ConnectionFactory; + +import org.springframework.util.Assert; + +/** + * An implementation of {@link ConnectionFactoryLookup} that + * simply wraps a single given {@link ConnectionFactory} + * returned for any connection factory name. + * + * @author Mark Paluch + * @since 5.3 + */ +public class SingleConnectionFactoryLookup implements ConnectionFactoryLookup { + + private final ConnectionFactory connectionFactory; + + + /** + * Create a new instance of the {@link SingleConnectionFactoryLookup} class. + * @param connectionFactory the single {@link ConnectionFactory} to wrap + */ + public SingleConnectionFactoryLookup(ConnectionFactory connectionFactory) { + Assert.notNull(connectionFactory, "ConnectionFactory must not be null"); + this.connectionFactory = connectionFactory; + } + + + @Override + public ConnectionFactory getConnectionFactory(String connectionFactoryName) + throws ConnectionFactoryLookupFailureException { + return this.connectionFactory; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/package-info.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/package-info.java new file mode 100644 index 000000000000..cd99f598865d --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/lookup/package-info.java @@ -0,0 +1,9 @@ +/** + * Provides a strategy for looking up R2DBC ConnectionFactories by name. + */ +@NonNullApi +@NonNullFields +package org.springframework.r2dbc.connection.lookup; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/package-info.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/package-info.java new file mode 100644 index 000000000000..5da1a594ea92 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/connection/package-info.java @@ -0,0 +1,11 @@ +/** + * Provides a utility class for easy ConnectionFactory access, + * a ReactiveTransactionManager for a single ConnectionFactory, + * and various simple ConnectionFactory implementations. + */ +@NonNullApi +@NonNullFields +package org.springframework.r2dbc.connection; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/BindParameterSource.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/BindParameterSource.java new file mode 100644 index 000000000000..ebc5bb195066 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/BindParameterSource.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import org.springframework.lang.Nullable; + +/** + * Interface that defines common functionality for objects + * that can offer parameter values for named bind parameters, + * serving as argument for {@link NamedParameterExpander} operations. + * + *

This interface allows for the specification of the type in + * addition to parameter values. All parameter values and types are + * identified by specifying the name of the parameter. + * + *

Intended to wrap various implementations like a {@link java.util.Map} + * with a consistent interface. + * + * @author Mark Paluch + * @since 5.3 + * @see MapBindParameterSource + */ +interface BindParameterSource { + + /** + * Determine whether there is a value for the specified named parameter. + * @param paramName the name of the parameter + * @return {@code true} if there is a value defined; {@code false} otherwise + */ + boolean hasValue(String paramName); + + /** + * Return the parameter value for the requested named parameter. + * @param paramName the name of the parameter + * @return the value of the specified parameter, can be {@code null} + * @throws IllegalArgumentException if there is no value + * for the requested parameter + */ + @Nullable + Object getValue(String paramName) throws IllegalArgumentException; + + /** + * Determine the type for the specified named parameter. + * @param paramName the name of the parameter + * @return the type of the specified parameter, or + * {@link Object#getClass()} if not known. + */ + default Class getType(String paramName) { + return Object.class; + } + + /** + * Return parameter names of the underlying parameter source. + * @return parameter names of the underlying parameter source. + */ + Iterable getParameterNames(); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ColumnMapRowMapper.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ColumnMapRowMapper.java new file mode 100644 index 000000000000..5067a2bc6c70 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ColumnMapRowMapper.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.Collection; +import java.util.Map; +import java.util.function.BiFunction; + +import io.r2dbc.spi.ColumnMetadata; +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; + +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedCaseInsensitiveMap; + +/** + * {@link BiFunction Mapping function} implementation that creates a + * {@code java.util.Map} for each row, representing all columns as + * key-value pairs: one entry for each column, with the column name as key. + * + *

The Map implementation to use and the key to use for each column + * in the column Map can be customized through overriding + * {@link #createColumnMap} and {@link #getColumnKey}, respectively. + * + *

Note: By default, ColumnMapRowMapper will try to build a linked Map + * with case-insensitive keys, to preserve column order as well as allow any + * casing to be used for column names. This requires Commons Collections on the + * classpath (which will be autodetected). Else, the fallback is a standard linked + * HashMap, which will still preserve column order but requires the application + * to specify the column names in the same casing as exposed by the driver. + * + * @author Mark Paluch + * @since 5.3 + */ +public class ColumnMapRowMapper implements BiFunction> { + + /** Default instance. */ + public final static ColumnMapRowMapper INSTANCE = new ColumnMapRowMapper(); + + + @Override + public Map apply(Row row, RowMetadata rowMetadata) { + Collection columns = rowMetadata.getColumnNames(); + int columnCount = columns.size(); + Map mapOfColValues = createColumnMap(columnCount); + int index = 0; + for (String column : columns) { + String key = getColumnKey(column); + Object obj = getColumnValue(row, index++); + mapOfColValues.put(key, obj); + } + return mapOfColValues; + } + + /** + * Create a {@link Map} instance to be used as column map. + *

By default, a linked case-insensitive Map will be created. + * @param columnCount the column count, to be used as initial capacity for the Map + * @return the new {@link Map} instance + * @see LinkedCaseInsensitiveMap + */ + protected Map createColumnMap(int columnCount) { + return new LinkedCaseInsensitiveMap<>(columnCount); + } + + /** + * Determine the key to use for the given column in the column {@link Map}. + * @param columnName the column name as returned by the {@link Row} + * @return the column key to use + * @see ColumnMetadata#getName() + */ + protected String getColumnKey(String columnName) { + return columnName; + } + + /** + * Retrieve a R2DBC object value for the specified column. + *

The default implementation uses the {@link Row#get(int)} method. + * @param row is the {@link Row} holding the data + * @param index is the column index + * @return the Object returned + */ + @Nullable + protected Object getColumnValue(Row row, int index) { + return row.get(index); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionAccessor.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionAccessor.java new file mode 100644 index 000000000000..bfeb09dd1bda --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionAccessor.java @@ -0,0 +1,66 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.Function; + +import io.r2dbc.spi.Connection; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.dao.DataAccessException; + +/** + * Interface declaring methods that accept callback {@link Function} + * to operate within the scope of a {@link Connection}. + * Callback functions operate on a provided connection and must not + * close the connection as the connections may be pooled or be + * subject to other kinds of resource management. + * + *

Callback functions are responsible for creating a + * {@link org.reactivestreams.Publisher} that defines the scope of how + * long the allocated {@link Connection} is valid. Connections are + * released after the publisher terminates. + * + * @author Mark Paluch + * @since 5.3 + */ +public interface ConnectionAccessor { + + /** + * Execute a callback {@link Function} within a {@link Connection} scope. + * The function is responsible for creating a {@link Mono}. The connection + * is released after the {@link Mono} terminates (or the subscription + * is cancelled). Connection resources must not be passed outside of the + * {@link Function} closure, otherwise resources may get defunct. + * @param action the callback object that specifies the connection action + * @return the resulting {@link Mono} + */ + Mono inConnection(Function> action) throws DataAccessException; + + /** + * Execute a callback {@link Function} within a {@link Connection} scope. + * The function is responsible for creating a {@link Flux}. The connection + * is released after the {@link Flux} terminates (or the subscription + * is cancelled). Connection resources must not be passed outside of the + * {@link Function} closure, otherwise resources may get defunct. + * @param action the callback object that specifies the connection action + * @return the resulting {@link Flux} + */ + Flux inConnectionMany(Function> action) throws DataAccessException; + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java new file mode 100644 index 000000000000..4d1b26529474 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ConnectionFunction.java @@ -0,0 +1,53 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.Function; + +import io.r2dbc.spi.Connection; + +/** + * Union type combining {@link Function} and {@link SqlProvider} to expose the SQL that is + * related to the underlying action. + * + * @author Mark Paluch + * @since 5.3 + * @param the type of the result of the function. + */ +class ConnectionFunction implements Function, SqlProvider { + + private final String sql; + + private final Function function; + + + ConnectionFunction(String sql, Function function) { + this.sql = sql; + this.function = function; + } + + + @Override + public R apply(Connection t) { + return this.function.apply(t); + } + + @Override + public String getSql() { + return this.sql; + } +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java new file mode 100644 index 000000000000..02f7ed352d5f --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DatabaseClient.java @@ -0,0 +1,250 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import io.r2dbc.spi.Statement; +import reactor.core.publisher.Mono; + +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.util.Assert; + +/** + * A non-blocking, reactive client for performing database calls requests with + * Reactive Streams back pressure. Provides a higher level, common API over + * R2DBC client libraries. + * + *

Use one of the static factory methods {@link #create(ConnectionFactory)} + * or obtain a {@link DatabaseClient#builder()} to create an instance. + * + * Usage example: + *

+ * ConnectionFactory factory = …
+ *
+ * DatabaseClient client = DatabaseClient.create(factory);
+ * Mono>Actor;lt actor = client.sql("select first_name, last_name from t_actor")
+ *     .map(row -> new Actor(row.get("first_name, String.class"),
+ *     row.get("last_name, String.class")))
+ *     .first();
+ * 
+ * + * @author Mark Paluch + * @since 5.3 + */ +public interface DatabaseClient extends ConnectionAccessor { + + /** + * Specify a static {@code sql} statement to run. Contract for specifying a + * SQL call along with options leading to the execution. The SQL string can + * contain either native parameter bind markers or named parameters (e.g. + * {@literal :foo, :bar}) when {@link NamedParameterExpander} is enabled. + * + * @param sql must not be {@code null} or empty + * @return a new {@link GenericExecuteSpec} + * @see NamedParameterExpander + * @see DatabaseClient.Builder#namedParameters(boolean) + */ + GenericExecuteSpec sql(String sql); + + /** + * Specify a {@link Supplier SQL supplier} that provides SQL to run. + * Contract for specifying a SQL call along with options leading to + * the execution. The SQL string can contain either native parameter + * bind markers or named parameters (e.g. {@literal :foo, :bar}) when + * {@link NamedParameterExpander} is enabled. + * + *

Accepts {@link PreparedOperation} as SQL and binding {@link Supplier} + * @param sqlSupplier must not be {@code null} + * @return a new {@link GenericExecuteSpec} + * @see NamedParameterExpander + * @see DatabaseClient.Builder#namedParameters(boolean) + * @see PreparedOperation + */ + GenericExecuteSpec sql(Supplier sqlSupplier); + + + // Static, factory methods + + /** + * Create a {@code DatabaseClient} that will use the provided {@link ConnectionFactory}. + * @param factory the {@code ConnectionFactory} to use for obtaining connections + * @return a new {@code DatabaseClient}. Guaranteed to be not {@code null}. + */ + static DatabaseClient create(ConnectionFactory factory) { + return new DefaultDatabaseClientBuilder().connectionFactory(factory).build(); + } + + /** + * Obtain a {@code DatabaseClient} builder. + */ + static DatabaseClient.Builder builder() { + return new DefaultDatabaseClientBuilder(); + } + + + /** + * A mutable builder for creating a {@link DatabaseClient}. + */ + interface Builder { + + /** + * Configure the {@link BindMarkersFactory BindMarkers} to be used. + * @param bindMarkers must not be {@code null} + */ + Builder bindMarkers(BindMarkersFactory bindMarkers); + + /** + * Configure the {@link ConnectionFactory R2DBC connector}. + * @param factory must not be {@code null} + */ + Builder connectionFactory(ConnectionFactory factory); + + /** + * Configure a {@link ExecuteFunction} to execute {@link Statement} objects. + * @param executeFunction must not be {@code null} + * @see Statement#execute() + */ + Builder executeFunction(ExecuteFunction executeFunction); + + /** + * Configure whether to use named parameter expansion. Defaults to {@code true}. + * @param enabled {@code true} to use named parameter expansion. + * {@code false} to disable named parameter expansion. + * @see NamedParameterExpander + */ + Builder namedParameters(boolean enabled); + + /** + * Configures a {@link Consumer} to configure this builder. + * @param builderConsumer must not be {@code null}. + */ + Builder apply(Consumer builderConsumer); + + /** + * Builder the {@link DatabaseClient} instance. + */ + DatabaseClient build(); + + } + + + /** + * Contract for specifying a SQL call along with options leading to the execution. + */ + interface GenericExecuteSpec { + + /** + * Bind a non-{@code null} value to a parameter identified by its + * {@code index}. {@code value} can be either a scalar value or {@link Parameter}. + * @param index zero based index to bind the parameter to + * @param value must not be {@code null}. Can be either a scalar value or {@link Parameter} + */ + GenericExecuteSpec bind(int index, Object value); + + /** + * Bind a {@code null} value to a parameter identified by its {@code index}. + * @param index zero based index to bind the parameter to + * @param type must not be {@code null} + */ + GenericExecuteSpec bindNull(int index, Class type); + + /** + * Bind a non-{@code null} value to a parameter identified by its {@code name}. + * @param name must not be {@code null} or empty + * @param value must not be {@code null} + */ + GenericExecuteSpec bind(String name, Object value); + + /** + * Bind a {@code null} value to a parameter identified by its {@code name}. + * @param name must not be {@code null} or empty + * @param type must not be {@code null} + */ + GenericExecuteSpec bindNull(String name, Class type); + + /** + * Add the given filter to the end of the filter chain. + *

Filter functions are typically used to invoke methods on the Statement + * before it is executed. + * + * For example: + *

+		 * DatabaseClient client = …;
+		 * client.sql("SELECT book_id FROM book").filter(statement -> statement.fetchSize(100))
+		 * 
+ * @param filter the filter to be added to the chain + */ + default GenericExecuteSpec filter(Function filter) { + Assert.notNull(filter, "Statement FilterFunction must not be null"); + return filter((statement, next) -> next.execute(filter.apply(statement))); + } + + /** + * Add the given filter to the end of the filter chain. + *

Filter functions are typically used to invoke methods on the Statement + * before it is executed. + * + * For example: + *

+		 * DatabaseClient client = …;
+		 * client.sql("SELECT book_id FROM book").filter((statement, next) -> next.execute(statement.fetchSize(100)))
+		 * 
+ * @param filter the filter to be added to the chain + */ + GenericExecuteSpec filter(StatementFilterFunction filter); + + /** + * Configure a result mapping {@link Function function} and enter the execution stage. + * @param mappingFunction must not be {@code null} + * @param result type. + * @return a {@link FetchSpec} for configuration what to fetch. Guaranteed to be not {@code null}. + */ + default RowsFetchSpec map(Function mappingFunction) { + Assert.notNull(mappingFunction, "Mapping function must not be null"); + return map((row, rowMetadata) -> mappingFunction.apply(row)); + } + + /** + * Configure a result mapping {@link BiFunction function} and enter the execution stage. + * @param mappingFunction must not be {@code null} + * @param result type. + * @return a {@link FetchSpec} for configuration what to fetch. Guaranteed to be not {@code null}. + */ + RowsFetchSpec map(BiFunction mappingFunction); + + /** + * Perform the SQL call and retrieve the result by entering the execution stage. + */ + FetchSpec> fetch(); + + /** + * Perform the SQL call and return a {@link Mono} that completes without result on statement completion. + * @return a {@link Mono} ignoring its payload (actively dropping). + */ + Mono then(); + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java new file mode 100644 index 000000000000..78f79d7277d4 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClient.java @@ -0,0 +1,603 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.R2dbcException; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import io.r2dbc.spi.Statement; +import io.r2dbc.spi.Wrapped; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.dao.DataAccessException; +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.lang.Nullable; +import org.springframework.r2dbc.connection.ConnectionFactoryUtils; +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.r2dbc.core.binding.BindTarget; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + + +/** + * Default implementation of {@link DatabaseClient}. + * + * @author Mark Paluch + * @author Mingyuan Wu + * @author Bogdan Ilchyshyn + * @since 5.3 + */ +class DefaultDatabaseClient implements DatabaseClient { + + private final Log logger = LogFactory.getLog(getClass()); + + private final BindMarkersFactory bindMarkersFactory; + + private final ConnectionFactory connectionFactory; + + private final ExecuteFunction executeFunction; + + private final boolean namedParameters; + + @Nullable + private final NamedParameterExpander namedParameterExpander; + + + DefaultDatabaseClient(BindMarkersFactory bindMarkersFactory, + ConnectionFactory connectionFactory, ExecuteFunction executeFunction, + boolean namedParameters) { + + this.bindMarkersFactory = bindMarkersFactory; + this.connectionFactory = connectionFactory; + this.executeFunction = executeFunction; + this.namedParameters = namedParameters; + this.namedParameterExpander = namedParameters ? new NamedParameterExpander() + : null; + } + + + @Override + public GenericExecuteSpec sql(String sql) { + Assert.hasText(sql, "SQL must not be null or empty"); + return sql(() -> sql); + } + + @Override + public GenericExecuteSpec sql(Supplier sqlSupplier) { + Assert.notNull(sqlSupplier, "SQL Supplier must not be null"); + return new DefaultGenericExecuteSpec(sqlSupplier); + } + + @Override + public Mono inConnection(Function> action) + throws DataAccessException { + Assert.notNull(action, "Callback object must not be null"); + Mono connectionMono = getConnection().map( + connection -> new ConnectionCloseHolder(connection, this::closeConnection)); + + return Mono.usingWhen(connectionMono, connectionCloseHolder -> { + + // Create close-suppressing Connection proxy + Connection connectionToUse = createConnectionProxy(connectionCloseHolder.connection); + + try { + return action.apply(connectionToUse); + } + catch (R2dbcException ex) { + String sql = getSql(action); + return Mono.error(ConnectionFactoryUtils.convertR2dbcException("doInConnection", sql, ex)); + } + }, ConnectionCloseHolder::close, (it, err) -> it.close(), + ConnectionCloseHolder::close) + .onErrorMap(R2dbcException.class, + ex -> ConnectionFactoryUtils.convertR2dbcException("execute", getSql(action), ex)); + } + + @Override + public Flux inConnectionMany(Function> action) + throws DataAccessException { + Assert.notNull(action, "Callback object must not be null"); + Mono connectionMono = getConnection().map( + connection -> new ConnectionCloseHolder(connection, this::closeConnection)); + + return Flux.usingWhen(connectionMono, connectionCloseHolder -> { + + // Create close-suppressing Connection proxy, also preparing returned + // Statements. + Connection connectionToUse = createConnectionProxy(connectionCloseHolder.connection); + + try { + return action.apply(connectionToUse); + } + catch (R2dbcException ex) { + String sql = getSql(action); + return Flux.error(ConnectionFactoryUtils.convertR2dbcException("doInConnectionMany", sql, ex)); + } + }, ConnectionCloseHolder::close, (it, err) -> it.close(), + ConnectionCloseHolder::close) + .onErrorMap(R2dbcException.class, + ex -> ConnectionFactoryUtils.convertR2dbcException("executeMany", getSql(action), ex)); + } + + /** + * Obtain a {@link Connection}. + * @return a {@link Mono} able to emit a {@link Connection} + */ + private Mono getConnection() { + return ConnectionFactoryUtils.getConnection(obtainConnectionFactory()); + } + + /** + * Release the {@link Connection}. + * @param connection to close. + * @return a {@link Publisher} that completes successfully when the connection is + * closed + */ + private Publisher closeConnection(Connection connection) { + + return ConnectionFactoryUtils.currentConnectionFactory( + obtainConnectionFactory()).then().onErrorResume(Exception.class, + e -> Mono.from(connection.close())); + } + + /** + * Obtain the {@link ConnectionFactory} for actual use. + * @return the ConnectionFactory (never {@code null}) + */ + private ConnectionFactory obtainConnectionFactory() { + return this.connectionFactory; + } + + /** + * Create a close-suppressing proxy for the given R2DBC + * Connection. Called by the {@code execute} method. + * @param con the R2DBC Connection to create a proxy for + * @return the Connection proxy + */ + private static Connection createConnectionProxy(Connection con) { + return (Connection) Proxy.newProxyInstance(DatabaseClient.class.getClassLoader(), + new Class[] { Connection.class, Wrapped.class }, + new CloseSuppressingInvocationHandler(con)); + } + + private static Mono sumRowsUpdated( + Function> resultFunction, Connection it) { + return resultFunction.apply(it) + .flatMap(Result::getRowsUpdated) + .collect(Collectors.summingInt(Integer::intValue)); + } + + /** + * Determine SQL from potential provider object. + * @param sqlProvider object that's potentially a SqlProvider + * @return the SQL string, or {@code null} + * @see SqlProvider + */ + @Nullable + private static String getSql(Object sqlProvider) { + + if (sqlProvider instanceof SqlProvider) { + return ((SqlProvider) sqlProvider).getSql(); + } + else { + return null; + } + } + + + /** + * Base class for {@link DatabaseClient.GenericExecuteSpec} implementations. + */ + class DefaultGenericExecuteSpec implements GenericExecuteSpec { + + final Map byIndex; + + final Map byName; + + final Supplier sqlSupplier; + + final StatementFilterFunction filterFunction; + + + DefaultGenericExecuteSpec(Supplier sqlSupplier) { + + this.byIndex = Collections.emptyMap(); + this.byName = Collections.emptyMap(); + this.sqlSupplier = sqlSupplier; + this.filterFunction = StatementFilterFunctions.empty(); + } + + DefaultGenericExecuteSpec(Map byIndex, Map byName, + Supplier sqlSupplier, StatementFilterFunction filterFunction) { + + this.byIndex = byIndex; + this.byName = byName; + this.sqlSupplier = sqlSupplier; + this.filterFunction = filterFunction; + } + + @Override + public DefaultGenericExecuteSpec bind(int index, Object value) { + assertNotPreparedOperation(); + Assert.notNull(value, () -> String.format( + "Value at index %d must not be null. Use bindNull(…) instead.", + index)); + + Map byIndex = new LinkedHashMap<>(this.byIndex); + + if (value instanceof Parameter) { + byIndex.put(index, (Parameter) value); + } + else { + byIndex.put(index, Parameter.fromOrEmpty(value, value.getClass())); + } + + return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction); + } + + @Override + public DefaultGenericExecuteSpec bindNull(int index, Class type) { + assertNotPreparedOperation(); + + Map byIndex = new LinkedHashMap<>(this.byIndex); + byIndex.put(index, Parameter.empty(type)); + + return new DefaultGenericExecuteSpec(byIndex, this.byName, this.sqlSupplier, this.filterFunction); + } + + @Override + public DefaultGenericExecuteSpec bind(String name, Object value) { + assertNotPreparedOperation(); + + Assert.hasText(name, "Parameter name must not be null or empty!"); + Assert.notNull(value, () -> String.format( + "Value for parameter %s must not be null. Use bindNull(…) instead.", + name)); + + Map byName = new LinkedHashMap<>(this.byName); + + if (value instanceof Parameter) { + byName.put(name, (Parameter) value); + } + else { + byName.put(name, Parameter.fromOrEmpty(value, value.getClass())); + } + + return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); + } + + @Override + public DefaultGenericExecuteSpec bindNull(String name, Class type) { + assertNotPreparedOperation(); + Assert.hasText(name, "Parameter name must not be null or empty!"); + + Map byName = new LinkedHashMap<>(this.byName); + byName.put(name, Parameter.empty(type)); + + return new DefaultGenericExecuteSpec(this.byIndex, byName, this.sqlSupplier, this.filterFunction); + } + + @Override + public DefaultGenericExecuteSpec filter(StatementFilterFunction filter) { + + Assert.notNull(filter, "Statement FilterFunction must not be null"); + + return new DefaultGenericExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, this.filterFunction.andThen(filter)); + } + + @Override + public FetchSpec map(BiFunction mappingFunction) { + Assert.notNull(mappingFunction, "Mapping function must not be null"); + return execute(this.sqlSupplier, mappingFunction); + } + + @Override + public FetchSpec> fetch() { + return execute(this.sqlSupplier, ColumnMapRowMapper.INSTANCE); + } + + @Override + public Mono then() { + return fetch().rowsUpdated().then(); + } + + private FetchSpec execute(Supplier sqlSupplier, + BiFunction mappingFunction) { + + String sql = getRequiredSql(sqlSupplier); + + Function statementFunction = connection -> { + + if (logger.isDebugEnabled()) { + logger.debug("Executing SQL statement [" + sql + "]"); + } + + if (sqlSupplier instanceof PreparedOperation) { + + Statement statement = connection.createStatement(sql); + BindTarget bindTarget = new StatementWrapper(statement); + ((PreparedOperation) sqlSupplier).bindTo(bindTarget); + + return statement; + } + + if (DefaultDatabaseClient.this.namedParameters) { + + Map remainderByName = new LinkedHashMap<>( + this.byName); + Map remainderByIndex = new LinkedHashMap<>( + this.byIndex); + + MapBindParameterSource namedBindings = retrieveParameters(sql, + remainderByName, remainderByIndex); + + PreparedOperation operation = DefaultDatabaseClient.this.namedParameterExpander.expand(sql, + DefaultDatabaseClient.this.bindMarkersFactory, namedBindings); + + String expanded = getRequiredSql(operation); + if (logger.isTraceEnabled()) { + logger.trace("Expanded SQL [" + expanded + "]"); + } + + Statement statement = connection.createStatement(expanded); + BindTarget bindTarget = new StatementWrapper(statement); + + operation.bindTo(bindTarget); + + bindByName(statement, remainderByName); + bindByIndex(statement, remainderByIndex); + + return statement; + } + + Statement statement = connection.createStatement(sql); + + bindByIndex(statement, this.byIndex); + bindByName(statement, this.byName); + + return statement; + }; + + Function> resultFunction = connection -> { + Statement statement = statementFunction.apply(connection); + return Flux.from(this.filterFunction.filter(statement, DefaultDatabaseClient.this.executeFunction)) + .cast(Result.class).checkpoint("SQL \"" + sql + "\" [DatabaseClient]"); + }; + + return new DefaultFetchSpec<>( + DefaultDatabaseClient.this, sql, + new ConnectionFunction<>(sql, resultFunction), + new ConnectionFunction<>(sql, connection -> sumRowsUpdated(resultFunction, connection)), + mappingFunction); + } + + private MapBindParameterSource retrieveParameters(String sql, + Map remainderByName, + Map remainderByIndex) { + List parameterNames = DefaultDatabaseClient.this.namedParameterExpander.getParameterNames(sql); + Map namedBindings = new LinkedHashMap<>( + parameterNames.size()); + for (String parameterName : parameterNames) { + + Parameter parameter = getParameter(remainderByName, remainderByIndex, + parameterNames, parameterName); + + if (parameter == null) { + throw new InvalidDataAccessApiUsageException( + String.format("No parameter specified for [%s] in query [%s]", + parameterName, sql)); + } + + namedBindings.put(parameterName, parameter); + } + return new MapBindParameterSource(namedBindings); + } + + @Nullable + private Parameter getParameter(Map remainderByName, + Map remainderByIndex, List parameterNames, + String parameterName) { + + if (this.byName.containsKey(parameterName)) { + remainderByName.remove(parameterName); + return this.byName.get(parameterName); + } + + int index = parameterNames.indexOf(parameterName); + if (this.byIndex.containsKey(index)) { + remainderByIndex.remove(index); + return this.byIndex.get(index); + } + + return null; + } + + private void assertNotPreparedOperation() { + if (this.sqlSupplier instanceof PreparedOperation) { + throw new InvalidDataAccessApiUsageException( + "Cannot add bindings to a PreparedOperation"); + } + } + + private void bindByName(Statement statement, Map byName) { + byName.forEach((name, parameter) -> { + if (parameter.hasValue()) { + statement.bind(name, parameter.getValue()); + } + else { + statement.bindNull(name, parameter.getType()); + } + }); + } + + private void bindByIndex(Statement statement, Map byIndex) { + byIndex.forEach((i, parameter) -> { + if (parameter.hasValue()) { + statement.bind(i, parameter.getValue()); + } + else { + statement.bindNull(i, parameter.getType()); + } + }); + } + + private String getRequiredSql(Supplier sqlSupplier) { + + String sql = sqlSupplier.get(); + Assert.state(StringUtils.hasText(sql), + "SQL returned by SQL supplier must not be empty!"); + return sql; + } + + } + + + /** + * Invocation handler that suppresses close calls on R2DBC Connections. Also prepares + * returned Statement (Prepared/CallbackStatement) objects. + * + * @see Connection#close() + */ + private static class CloseSuppressingInvocationHandler implements InvocationHandler { + + private final Connection target; + + + CloseSuppressingInvocationHandler(Connection target) { + this.target = target; + } + + @Override + @Nullable + public Object invoke(Object proxy, Method method, Object[] args) + throws Throwable { + // Invocation on ConnectionProxy interface coming in... + + if (method.getName().equals("equals")) { + // Only consider equal when proxies are identical. + return proxy == args[0]; + } + else if (method.getName().equals("hashCode")) { + // Use hashCode of PersistenceManager proxy. + return System.identityHashCode(proxy); + } + else if (method.getName().equals("unwrap")) { + return this.target; + } + else if (method.getName().equals("close")) { + // Handle close method: suppress, not valid. + return Mono.error( + new UnsupportedOperationException("Close is not supported!")); + } + + // Invoke method on target Connection. + try { + return method.invoke(this.target, args); + } + catch (InvocationTargetException ex) { + throw ex.getTargetException(); + } + } + + } + + /** + * Holder for a connection that makes sure the close action is invoked atomically only + * once. + */ + static class ConnectionCloseHolder extends AtomicBoolean { + + private static final long serialVersionUID = -8994138383301201380L; + + final Connection connection; + + final Function> closeFunction; + + + ConnectionCloseHolder(Connection connection, + Function> closeFunction) { + this.connection = connection; + this.closeFunction = closeFunction; + } + + Mono close() { + + return Mono.defer(() -> { + + if (compareAndSet(false, true)) { + return Mono.from(this.closeFunction.apply(this.connection)); + } + + return Mono.empty(); + }); + } + + } + + static class StatementWrapper implements BindTarget { + + final Statement statement; + + + StatementWrapper(Statement statement) { + this.statement = statement; + } + + @Override + public void bind(String identifier, Object value) { + this.statement.bind(identifier, value); + } + + @Override + public void bind(int index, Object value) { + this.statement.bind(index, value); + } + + @Override + public void bindNull(String identifier, Class type) { + this.statement.bindNull(identifier, type); + } + + @Override + public void bindNull(int index, Class type) { + this.statement.bindNull(index, type); + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClientBuilder.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClientBuilder.java new file mode 100644 index 000000000000..add092bc7287 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultDatabaseClientBuilder.java @@ -0,0 +1,106 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.Consumer; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Statement; + +import org.springframework.lang.Nullable; +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver; +import org.springframework.util.Assert; + +/** + * Default implementation of {@link DatabaseClient.Builder}. + * + * @author Mark Paluch + * @since 5.3 + */ +class DefaultDatabaseClientBuilder implements DatabaseClient.Builder { + + @Nullable + private BindMarkersFactory bindMarkers; + + @Nullable + private ConnectionFactory connectionFactory; + + private ExecuteFunction executeFunction = Statement::execute; + + private boolean namedParameters = true; + + + DefaultDatabaseClientBuilder() { + } + + + @Override + public DatabaseClient.Builder bindMarkers(BindMarkersFactory bindMarkers) { + Assert.notNull(bindMarkers, "BindMarkersFactory must not be null"); + this.bindMarkers = bindMarkers; + return this; + } + + @Override + public DatabaseClient.Builder connectionFactory(ConnectionFactory factory) { + Assert.notNull(factory, "ConnectionFactory must not be null"); + this.connectionFactory = factory; + return this; + } + + @Override + public DatabaseClient.Builder executeFunction(ExecuteFunction executeFunction) { + Assert.notNull(executeFunction, "ExecuteFunction must not be null"); + this.executeFunction = executeFunction; + return this; + } + + @Override + public DatabaseClient.Builder namedParameters(boolean enabled) { + this.namedParameters = enabled; + return this; + } + + @Override + public DatabaseClient build() { + Assert.notNull(this.connectionFactory, "ConnectionFactory must not be null"); + + BindMarkersFactory bindMarkers = this.bindMarkers; + + if (bindMarkers == null) { + if (this.namedParameters) { + bindMarkers = BindMarkersFactoryResolver.resolve(this.connectionFactory); + } + else { + bindMarkers = BindMarkersFactory.anonymous("?"); + } + } + + return new DefaultDatabaseClient(bindMarkers, this.connectionFactory, + this.executeFunction, this.namedParameters); + } + + @Override + public DatabaseClient.Builder apply( + Consumer builderConsumer) { + Assert.notNull(builderConsumer, "BuilderConsumer must not be null"); + builderConsumer.accept(this); + return this; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java new file mode 100644 index 000000000000..ef2177ce1ee7 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/DefaultFetchSpec.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.BiFunction; +import java.util.function.Function; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.dao.IncorrectResultSizeDataAccessException; + +/** + * Default {@link FetchSpec} implementation. + * + * @author Mark Paluch + * @since 5.3 + * @param the row result type + */ +class DefaultFetchSpec implements FetchSpec { + + private final ConnectionAccessor connectionAccessor; + + private final String sql; + + private final Function> resultFunction; + + private final Function> updatedRowsFunction; + + private final BiFunction mappingFunction; + + + DefaultFetchSpec(ConnectionAccessor connectionAccessor, String sql, + Function> resultFunction, + Function> updatedRowsFunction, + BiFunction mappingFunction) { + this.sql = sql; + this.connectionAccessor = connectionAccessor; + this.resultFunction = resultFunction; + this.updatedRowsFunction = updatedRowsFunction; + this.mappingFunction = mappingFunction; + } + + + @Override + public Mono one() { + return all().buffer(2) + .flatMap(list -> { + + if (list.isEmpty()) { + return Mono.empty(); + } + + if (list.size() > 1) { + return Mono.error(new IncorrectResultSizeDataAccessException( + String.format("Query [%s] returned non unique result.", + this.sql), + 1)); + } + + return Mono.just(list.get(0)); + }).next(); + } + + @Override + public Mono first() { + return all().next(); + } + + @Override + public Flux all() { + return this.connectionAccessor.inConnectionMany(new ConnectionFunction<>(this.sql, + connection -> this.resultFunction.apply(connection) + .flatMap(result -> result.map(this.mappingFunction)))); + } + + @Override + public Mono rowsUpdated() { + return this.connectionAccessor.inConnection(this.updatedRowsFunction); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ExecuteFunction.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ExecuteFunction.java new file mode 100644 index 000000000000..550974e5e4c0 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ExecuteFunction.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.BiFunction; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; +import org.reactivestreams.Publisher; + +/** + * Represents a function that executes a {@link Statement} for a (delayed) + * {@link Result} stream. + * + *

Note that discarded {@link Result} objects must be consumed according + * to the R2DBC spec via either {@link Result#getRowsUpdated()} or + * {@link Result#map(BiFunction)}. + * + *

Typically, implementations invoke the {@link Statement#execute()} method + * to initiate execution of the statement object. + * + * For example: + *

+ * DatabaseClient.builder()
+ *		.executeFunction(statement -> statement.execute())
+ * 		.build();
+ * 
+ * + * @author Mark Paluch + * @since 5.3 + * @see Statement#execute() + */ +@FunctionalInterface +public interface ExecuteFunction { + + /** + * Execute the given {@link Statement} for a stream of {@link Result}s. + * @param statement the request to execute + * @return the delayed result stream + */ + Publisher execute(Statement statement); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/FetchSpec.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/FetchSpec.java new file mode 100644 index 000000000000..667da7883408 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/FetchSpec.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +/** + * Union type for fetching results. + * + * @author Mark Paluch + * @since 5.3 + * @param the row result type + * @see RowsFetchSpec + * @see UpdatedRowsFetchSpec + */ +public interface FetchSpec extends RowsFetchSpec, UpdatedRowsFetchSpec {} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/MapBindParameterSource.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/MapBindParameterSource.java new file mode 100644 index 000000000000..14c8754fb458 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/MapBindParameterSource.java @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.util.Assert; + +/** + * {@link BindParameterSource} implementation that holds a given {@link Map} of parameters + * encapsulated as {@link Parameter}. + * + *

This class is intended for passing in a simple Map of parameter values to the methods + * of the {@code NamedParameterExpander} class. + * + * @author Mark Paluch + * @since 5.3 + */ +class MapBindParameterSource implements BindParameterSource { + + private final Map values; + + + /** + * Create a new empty {@link MapBindParameterSource}. + */ + MapBindParameterSource() { + this(new LinkedHashMap<>()); + } + + /** + * Creates a new {@link MapBindParameterSource} given {@link Map} of + * {@link Parameter}. + * + * @param values the parameter mapping. + */ + MapBindParameterSource(Map values) { + + Assert.notNull(values, "Values must not be null"); + + this.values = values; + } + + + /** + * Add a key-value pair to the {@link MapBindParameterSource}. The value must not be + * {@code null}. + * + * @param paramName must not be {@code null}. + * @param value must not be {@code null}. + * @return {@code this} {@link MapBindParameterSource} + */ + MapBindParameterSource addValue(String paramName, Object value) { + Assert.notNull(paramName, "Parameter name must not be null"); + Assert.notNull(value, "Value must not be null"); + this.values.put(paramName, Parameter.fromOrEmpty(value, value.getClass())); + return this; + } + + @Override + public boolean hasValue(String paramName) { + Assert.notNull(paramName, "Parameter name must not be null"); + return this.values.containsKey(paramName); + } + + @Override + public Class getType(String paramName) { + Assert.notNull(paramName, "Parameter name must not be null"); + Parameter parameter = this.values.get(paramName); + if (parameter != null) { + return parameter.getType(); + } + return Object.class; + } + + @Override + public Object getValue(String paramName) throws IllegalArgumentException { + if (!hasValue(paramName)) { + throw new IllegalArgumentException( + "No value registered for key '" + paramName + "'"); + } + return this.values.get(paramName).getValue(); + } + + @Override + public Iterable getParameterNames() { + return this.values.keySet(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterExpander.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterExpander.java new file mode 100644 index 000000000000..d630cb4d34c3 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterExpander.java @@ -0,0 +1,162 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.r2dbc.core.binding.BindMarkersFactory; + + +/** + * SQL translation support allowing the use of named parameters + * rather than native placeholders. + * + *

This class expands SQL from named parameters to native + * style placeholders at execution time. It also allows for expanding + * a {@link List} of values to the appropriate number of placeholders. + * + *

References to the same parameter name are substituted with the + * same bind marker placeholder if a {@link BindMarkersFactory} uses + * {@link BindMarkersFactory#identifiablePlaceholders() identifiable} placeholders. + *

NOTE: An instance of this class is thread-safe once configured. + * + * @author Mark Paluch + */ +class NamedParameterExpander { + + /** + * Default maximum number of entries for the SQL cache: 256. + */ + public static final int DEFAULT_CACHE_LIMIT = 256; + + + private volatile int cacheLimit = DEFAULT_CACHE_LIMIT; + + private final Log logger = LogFactory.getLog(getClass()); + + /** + * Cache of original SQL String to ParsedSql representation. + */ + @SuppressWarnings("serial") + private final Map parsedSqlCache = new LinkedHashMap( + DEFAULT_CACHE_LIMIT, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > getCacheLimit(); + } + }; + + + /** + * Create a new enabled instance of {@link NamedParameterExpander}. + */ + public NamedParameterExpander() {} + + + /** + * Specify the maximum number of entries for the SQL cache. Default is 256. + */ + public void setCacheLimit(int cacheLimit) { + this.cacheLimit = cacheLimit; + } + + /** + * Return the maximum number of entries for the SQL cache. + */ + public int getCacheLimit() { + return this.cacheLimit; + } + + /** + * Obtain a parsed representation of the given SQL statement. + *

+ * The default implementation uses an LRU cache with an upper limit of 256 entries. + * + * @param sql the original SQL statement + * @return a representation of the parsed SQL statement + */ + private ParsedSql getParsedSql(String sql) { + + if (getCacheLimit() <= 0) { + return NamedParameterUtils.parseSqlStatement(sql); + } + + synchronized (this.parsedSqlCache) { + + ParsedSql parsedSql = this.parsedSqlCache.get(sql); + if (parsedSql == null) { + + parsedSql = NamedParameterUtils.parseSqlStatement(sql); + this.parsedSqlCache.put(sql, parsedSql); + } + return parsedSql; + } + } + + /** + * Parse the SQL statement and locate any placeholders or named parameters. + * Named parameters are substituted for a native placeholder, and any + * select list is expanded to the required number of placeholders. Select + * lists may contain an array of objects, and in that case the placeholders + * will be grouped and enclosed with parentheses. This allows for the use of + * "expression lists" in the SQL statement like: + * + *

+	 * select id, name, state from table where (name, age) in (('John', 35), ('Ann', 50))
+	 * 
+ * + *

The parameter values passed in are used to determine the number of + * placeholders to be used for a select list. Select lists should be limited + * to 100 or fewer elements. A larger number of elements is not guaranteed to be + * supported by the database and is strictly vendor-dependent. + * @param sql sql the original SQL statement + * @param bindMarkersFactory the bind marker factory + * @param paramSource the source for named parameters + * @return the expanded sql that accepts bind parameters and allows for execution + * without further translation wrapped as {@link PreparedOperation}. + */ + public PreparedOperation expand(String sql, BindMarkersFactory bindMarkersFactory, + BindParameterSource paramSource) { + + ParsedSql parsedSql = getParsedSql(sql); + + PreparedOperation expanded = NamedParameterUtils.substituteNamedParameters(parsedSql, bindMarkersFactory, + paramSource); + + if (logger.isDebugEnabled()) { + logger.debug(String.format("Expanding SQL statement [%s] to [%s]", sql, expanded.toQuery())); + } + + return expanded; + } + + /** + * Parse the SQL statement and locate any placeholders or named parameters. Named parameters are returned as result of + * this method invocation. + * + * @return the parameter names. + */ + public List getParameterNames(String sql) { + return getParsedSql(sql).getParameterNames(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java new file mode 100644 index 000000000000..fe988f1b349c --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/NamedParameterUtils.java @@ -0,0 +1,630 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; + +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.lang.Nullable; +import org.springframework.r2dbc.core.binding.BindMarker; +import org.springframework.r2dbc.core.binding.BindMarkers; +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.r2dbc.core.binding.BindTarget; +import org.springframework.util.Assert; + +/** + * Helper methods for named parameter parsing. + * + *

Only intended for internal use within Spring's R2DBC + * framework. + * + *

References to the same parameter name are substituted with + * the same bind marker placeholder if a {@link BindMarkersFactory} uses + * {@link BindMarkersFactory#identifiablePlaceholders() identifiable} + * placeholders. + * + * @author Thomas Risberg + * @author Juergen Hoeller + * @author Mark Paluch + * @since 5.3 + */ +abstract class NamedParameterUtils { + + /** + * Set of characters that qualify as comment or quotes starting characters. + */ + private static final String[] START_SKIP = new String[] {"'", "\"", "--", "/*"}; + + /** + * Set of characters that at are the corresponding comment or quotes ending characters. + */ + private static final String[] STOP_SKIP = new String[] {"'", "\"", "\n", "*/"}; + + /** + * Set of characters that qualify as parameter separators, + * indicating that a parameter name in an SQL String has ended. + */ + private static final String PARAMETER_SEPARATORS = "\"':&,;()|=+-*%/\\<>^"; + + /** + * An index with separator flags per character code. + * Technically only needed between 34 and 124 at this point. + */ + private static final boolean[] separatorIndex = new boolean[128]; + + static { + for (char c : PARAMETER_SEPARATORS.toCharArray()) { + separatorIndex[c] = true; + } + } + + + // ------------------------------------------------------------------------- + // Core methods used by NamedParameterSupport. + // ------------------------------------------------------------------------- + + /** + * Parse the SQL statement and locate any placeholders or named parameters. + * Namedparameters are substituted for a R2DBC placeholder. + * + * @param sql the SQL statement + * @return the parsed statement, represented as {@link ParsedSql} instance. + */ + public static ParsedSql parseSqlStatement(String sql) { + Assert.notNull(sql, "SQL must not be null"); + + Set namedParameters = new HashSet<>(); + String sqlToUse = sql; + List parameterList = new ArrayList<>(); + + char[] statement = sql.toCharArray(); + int namedParameterCount = 0; + int unnamedParameterCount = 0; + int totalParameterCount = 0; + + int escapes = 0; + int i = 0; + while (i < statement.length) { + int skipToPosition = i; + while (i < statement.length) { + skipToPosition = skipCommentsAndQuotes(statement, i); + if (i == skipToPosition) { + break; + } + else { + i = skipToPosition; + } + } + if (i >= statement.length) { + break; + } + char c = statement[i]; + if (c == ':' || c == '&') { + int j = i + 1; + if (c == ':' && j < statement.length && statement[j] == ':') { + // Postgres-style "::" casting operator should be skipped + i = i + 2; + continue; + } + String parameter = null; + if (c == ':' && j < statement.length && statement[j] == '{') { + // :{x} style parameter + while (statement[j] != '}') { + j++; + if (j >= statement.length) { + throw new InvalidDataAccessApiUsageException("Non-terminated named parameter declaration " + + "at position " + i + " in statement: " + sql); + } + if (statement[j] == ':' || statement[j] == '{') { + throw new InvalidDataAccessApiUsageException("Parameter name contains invalid character '" + + statement[j] + "' at position " + i + " in statement: " + sql); + } + } + if (j - i > 2) { + parameter = sql.substring(i + 2, j); + namedParameterCount = addNewNamedParameter(namedParameters, + namedParameterCount, parameter); + totalParameterCount = addNamedParameter(parameterList, + totalParameterCount, escapes, i, j + 1, parameter); + } + j++; + } + else { + while (j < statement.length && !isParameterSeparator(statement[j])) { + j++; + } + if (j - i > 1) { + parameter = sql.substring(i + 1, j); + namedParameterCount = addNewNamedParameter(namedParameters, + namedParameterCount, parameter); + totalParameterCount = addNamedParameter(parameterList, + totalParameterCount, escapes, i, j, parameter); + } + } + i = j - 1; + } + else { + if (c == '\\') { + int j = i + 1; + if (j < statement.length && statement[j] == ':') { + // escaped ":" should be skipped + sqlToUse = sqlToUse.substring(0, i - escapes) + + sqlToUse.substring(i - escapes + 1); + escapes++; + i = i + 2; + continue; + } + } + } + i++; + } + ParsedSql parsedSql = new ParsedSql(sqlToUse); + for (ParameterHolder ph : parameterList) { + parsedSql.addNamedParameter(ph.getParameterName(), ph.getStartIndex(), ph.getEndIndex()); + } + parsedSql.setNamedParameterCount(namedParameterCount); + parsedSql.setUnnamedParameterCount(unnamedParameterCount); + parsedSql.setTotalParameterCount(totalParameterCount); + return parsedSql; + } + + private static int addNamedParameter( + List parameterList, int totalParameterCount, int escapes, int i, int j, String parameter) { + + parameterList.add(new ParameterHolder(parameter, i - escapes, j - escapes)); + totalParameterCount++; + return totalParameterCount; + } + + private static int addNewNamedParameter(Set namedParameters, int namedParameterCount, String parameter) { + if (!namedParameters.contains(parameter)) { + namedParameters.add(parameter); + namedParameterCount++; + } + return namedParameterCount; + } + + /** + * Skip over comments and quoted names present in an SQL statement. + * @param statement character array containing SQL statement + * @param position current position of statement + * @return next position to process after any comments or quotes are skipped + */ + private static int skipCommentsAndQuotes(char[] statement, int position) { + for (int i = 0; i < START_SKIP.length; i++) { + if (statement[position] == START_SKIP[i].charAt(0)) { + boolean match = true; + for (int j = 1; j < START_SKIP[i].length(); j++) { + if (statement[position + j] != START_SKIP[i].charAt(j)) { + match = false; + break; + } + } + if (match) { + int offset = START_SKIP[i].length(); + for (int m = position + offset; m < statement.length; m++) { + if (statement[m] == STOP_SKIP[i].charAt(0)) { + boolean endMatch = true; + int endPos = m; + for (int n = 1; n < STOP_SKIP[i].length(); n++) { + if (m + n >= statement.length) { + // last comment not closed properly + return statement.length; + } + if (statement[m + n] != STOP_SKIP[i].charAt(n)) { + endMatch = false; + break; + } + endPos = m + n; + } + if (endMatch) { + // found character sequence ending comment or quote + return endPos + 1; + } + } + } + // character sequence ending comment or quote not found + return statement.length; + } + } + } + return position; + } + + /** + * Parse the SQL statement and locate any placeholders or named parameters. Named + * parameters are substituted for a R2DBC placeholder, and any select list is expanded + * to the required number of placeholders. Select lists may contain an array of + * objects, and in that case the placeholders will be grouped and enclosed with + * parentheses. This allows for the use of "expression lists" in the SQL statement + * like:

+ * {@code select id, name, state from table where (name, age) in (('John', 35), ('Ann', 50))} + *

The parameter values passed in are used to determine the number of placeholders to + * be used for a select list. Select lists should be limited to 100 or fewer elements. + * A larger number of elements is not guaranteed to be supported by the database and + * is strictly vendor-dependent. + * @param parsedSql the parsed representation of the SQL statement + * @param bindMarkersFactory the bind marker factory. + * @param paramSource the source for named parameters + * @return the expanded query that accepts bind parameters and allows for execution + * without further translation + * @see #parseSqlStatement + */ + public static PreparedOperation substituteNamedParameters(ParsedSql parsedSql, + BindMarkersFactory bindMarkersFactory, BindParameterSource paramSource) { + NamedParameters markerHolder = new NamedParameters(bindMarkersFactory); + String originalSql = parsedSql.getOriginalSql(); + List paramNames = parsedSql.getParameterNames(); + if (paramNames.isEmpty()) { + return new ExpandedQuery(originalSql, markerHolder, paramSource); + } + StringBuilder actualSql = new StringBuilder(originalSql.length()); + int lastIndex = 0; + for (int i = 0; i < paramNames.size(); i++) { + String paramName = paramNames.get(i); + int[] indexes = parsedSql.getParameterIndexes(i); + int startIndex = indexes[0]; + int endIndex = indexes[1]; + actualSql.append(originalSql, lastIndex, startIndex); + NamedParameters.NamedParameter marker = markerHolder.getOrCreate(paramName); + if (paramSource.hasValue(paramName)) { + Object value = paramSource.getValue(paramName); + if (value instanceof Collection) { + + Iterator entryIter = ((Collection) value).iterator(); + int k = 0; + int counter = 0; + while (entryIter.hasNext()) { + if (k > 0) { + actualSql.append(", "); + } + k++; + Object entryItem = entryIter.next(); + if (entryItem instanceof Object[]) { + Object[] expressionList = (Object[]) entryItem; + actualSql.append('('); + for (int m = 0; m < expressionList.length; m++) { + if (m > 0) { + actualSql.append(", "); + } + actualSql.append(marker.getPlaceholder(counter)); + counter++; + } + actualSql.append(')'); + } + else { + actualSql.append(marker.getPlaceholder(counter)); + counter++; + } + + } + } + else { + actualSql.append(marker.getPlaceholder()); + } + } + else { + actualSql.append(marker.getPlaceholder()); + } + lastIndex = endIndex; + } + actualSql.append(originalSql, lastIndex, originalSql.length()); + + return new ExpandedQuery(actualSql.toString(), markerHolder, paramSource); + } + + /** + * Determine whether a parameter name ends at the current position, + * that is, whether the given character qualifies as a separator. + */ + private static boolean isParameterSeparator(char c) { + return (c < 128 && separatorIndex[c]) || Character.isWhitespace(c); + } + + // ------------------------------------------------------------------------- + // Convenience methods operating on a plain SQL String + // ------------------------------------------------------------------------- + + /** + * Parse the SQL statement and locate any placeholders or named parameters. + * Named parameters are substituted for a native placeholder and any + * select list is expanded to the required number of placeholders. + * @param sql the SQL statement + * @param bindMarkersFactory the bind marker factory + * @param paramSource the source for named parameters + * @return the expanded query that accepts bind parameters and allows for execution + * without further translation + */ + public static PreparedOperation substituteNamedParameters(String sql, + BindMarkersFactory bindMarkersFactory, BindParameterSource paramSource) { + ParsedSql parsedSql = parseSqlStatement(sql); + return substituteNamedParameters(parsedSql, bindMarkersFactory, paramSource); + } + + + private static class ParameterHolder { + + private final String parameterName; + + private final int startIndex; + + private final int endIndex; + + ParameterHolder(String parameterName, int startIndex, int endIndex) { + this.parameterName = parameterName; + this.startIndex = startIndex; + this.endIndex = endIndex; + } + + String getParameterName() { + return this.parameterName; + } + + int getStartIndex() { + return this.startIndex; + } + + int getEndIndex() { + return this.endIndex; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ParameterHolder)) { + return false; + } + ParameterHolder that = (ParameterHolder) o; + return this.startIndex == that.startIndex && this.endIndex == that.endIndex + && Objects.equals(this.parameterName, that.parameterName); + } + + @Override + public int hashCode() { + return Objects.hash(this.parameterName, this.startIndex, this.endIndex); + } + + } + + /** + * Holder for bind markers progress. + */ + static class NamedParameters { + + private final BindMarkers bindMarkers; + + private final boolean identifiable; + + private final Map> references = new TreeMap<>(); + + NamedParameters(BindMarkersFactory factory) { + this.bindMarkers = factory.create(); + this.identifiable = factory.identifiablePlaceholders(); + } + + + /** + * Get the {@link NamedParameter} identified by {@code namedParameter}. + * Parameter objects get created if they do not yet exist. + * @param namedParameter the parameter name + * @return the named parameter + */ + NamedParameter getOrCreate(String namedParameter) { + + List reference = this.references.computeIfAbsent( + namedParameter, ignore -> new ArrayList<>()); + + if (reference.isEmpty()) { + NamedParameter param = new NamedParameter(namedParameter); + reference.add(param); + return param; + } + + if (this.identifiable) { + return reference.get(0); + } + + NamedParameter param = new NamedParameter(namedParameter); + reference.add(param); + return param; + } + + @Nullable + List getMarker(String name) { + return this.references.get(name); + } + + class NamedParameter { + + private final String namedParameter; + + private final List placeholders = new ArrayList<>(); + + + NamedParameter(String namedParameter) { + this.namedParameter = namedParameter; + } + + /** + * Create a placeholder to translate a single value into a bindable parameter. + *

Can be called multiple times to create placeholders for array/collections. + * @return the placeholder to be used in the SQL statement + */ + String addPlaceholder() { + + BindMarker bindMarker = NamedParameters.this.bindMarkers.next( + this.namedParameter); + this.placeholders.add(bindMarker); + return bindMarker.getPlaceholder(); + } + + String getPlaceholder() { + return getPlaceholder(0); + } + + String getPlaceholder(int counter) { + + while (counter + 1 > this.placeholders.size()) { + addPlaceholder(); + } + + return this.placeholders.get(counter).getPlaceholder(); + } + } + + } + + /** + * Expanded query that allows binding of parameters using parameter names that were + * used to expand the query. Binding unrolls {@link Collection}s and nested arrays. + */ + private static class ExpandedQuery implements PreparedOperation { + + private final String expandedSql; + + private final NamedParameters parameters; + + private final BindParameterSource parameterSource; + + + ExpandedQuery(String expandedSql, NamedParameters parameters, + BindParameterSource parameterSource) { + this.expandedSql = expandedSql; + this.parameters = parameters; + this.parameterSource = parameterSource; + } + + + @SuppressWarnings("unchecked") + public void bind(BindTarget target, String identifier, Object value) { + + List bindMarkers = getBindMarkers(identifier); + + if (bindMarkers == null) { + target.bind(identifier, value); + return; + } + + if (value instanceof Collection) { + Collection collection = (Collection) value; + + Iterator iterator = collection.iterator(); + Iterator markers = bindMarkers.iterator(); + + while (iterator.hasNext()) { + + Object valueToBind = iterator.next(); + + if (valueToBind instanceof Object[]) { + Object[] objects = (Object[]) valueToBind; + for (Object object : objects) { + bind(target, markers, object); + } + } + else { + bind(target, markers, valueToBind); + } + } + } + else { + for (BindMarker bindMarker : bindMarkers) { + bindMarker.bind(target, value); + } + } + } + + private void bind(BindTarget target, Iterator markers, + Object valueToBind) { + + Assert.isTrue(markers.hasNext(), () -> String.format( + "No bind marker for value [%s] in SQL [%s]. Check that the query was expanded using the same arguments.", + valueToBind, toQuery())); + + markers.next().bind(target, valueToBind); + } + + public void bindNull(BindTarget target, String identifier, Class valueType) { + List bindMarkers = getBindMarkers(identifier); + + if (bindMarkers == null) { + target.bindNull(identifier, valueType); + return; + } + + for (BindMarker bindMarker : bindMarkers) { + bindMarker.bindNull(target, valueType); + } + } + + @Nullable + List getBindMarkers(String identifier) { + List parameters = this.parameters.getMarker( + identifier); + + if (parameters == null) { + return null; + } + + List markers = new ArrayList<>(); + + for (NamedParameters.NamedParameter parameter : parameters) { + markers.addAll(parameter.placeholders); + } + + return markers; + } + + @Override + public String getSource() { + return this.expandedSql; + } + + @Override + public void bindTo(BindTarget target) { + + for (String namedParameter : this.parameterSource.getParameterNames()) { + + Object value = this.parameterSource.getValue(namedParameter); + + if (value == null) { + bindNull(target, namedParameter, + this.parameterSource.getType(namedParameter)); + } + else { + bind(target, namedParameter, value); + } + } + } + + @Override + public String toQuery() { + return this.expandedSql; + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/Parameter.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/Parameter.java new file mode 100644 index 000000000000..f74d9b349cb4 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/Parameter.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.Objects; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; + +/** + * A database value that can be set in a statement. + * + * @author Mark Paluch + * @since 5.3 + */ +public final class Parameter { + + @Nullable + private final Object value; + + private final Class type; + + private Parameter(@Nullable Object value, Class type) { + Assert.notNull(type, "Type must not be null"); + this.value = value; + this.type = type; + } + + + /** + * Create a new {@link Parameter} from {@code value}. + * @param value must not be {@code null} + * @return the {@link Parameter} value for {@code value} + */ + public static Parameter from(Object value) { + Assert.notNull(value, "Value must not be null"); + return new Parameter(value, ClassUtils.getUserClass(value)); + } + + /** + * Create a new {@link Parameter} from {@code value} and {@code type}. + * @param value can be {@code null} + * @param type must not be {@code null} + * @return the {@link Parameter} value for {@code value} + */ + public static Parameter fromOrEmpty(@Nullable Object value, Class type) { + return value == null ? empty(type) : new Parameter(value, ClassUtils.getUserClass(value)); + } + + /** + * Create a new empty {@link Parameter} for {@code type}. + * @return the empty {@link Parameter} value for {@code type} + */ + public static Parameter empty(Class type) { + Assert.notNull(type, "Type must not be null"); + return new Parameter(null, type); + } + + + /** + * Returns the column value. Can be {@code null}. + * @return the column value. Can be {@code null} + * @see #hasValue() + */ + @Nullable + public Object getValue() { + return this.value; + } + + /** + * Returns the column value type. Must be also present if the {@code value} is {@code null}. + * @return the column value type + */ + public Class getType() { + return this.type; + } + + /** + * Returns whether this {@link Parameter} has a value. + * @return whether this {@link Parameter} has a value. {@code false} if {@link #getValue()} is {@code null} + */ + public boolean hasValue() { + return this.value != null; + } + + /** + * Returns whether this {@link Parameter} has a empty. + * @return whether this {@link Parameter} is empty. {@code true} if {@link #getValue()} is {@code null} + */ + public boolean isEmpty() { + return this.value == null; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Parameter)) { + return false; + } + Parameter other = (Parameter) o; + return ObjectUtils.nullSafeEquals(this.value, other.value) && ObjectUtils.nullSafeEquals(this.type, other.type); + } + + @Override + public int hashCode() { + return Objects.hash(this.value, this.type); + } + + @Override + public String toString() { + StringBuffer sb = new StringBuffer(); + sb.append("Parameter"); + sb.append("[value=").append(this.value); + sb.append(", type=").append(this.type); + sb.append(']'); + return sb.toString(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ParsedSql.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ParsedSql.java new file mode 100644 index 000000000000..818cae130ec3 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/ParsedSql.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.ArrayList; +import java.util.List; + +/** + * Holds information about a parsed SQL statement. + * + * @author Thomas Risberg + * @author Juergen Hoeller + * @since 5.3 + */ +class ParsedSql { + + private String originalSql; + + private List parameterNames = new ArrayList<>(); + + private List parameterIndexes = new ArrayList<>(); + + private int namedParameterCount; + + private int unnamedParameterCount; + + private int totalParameterCount; + + + /** + * Create a new instance of the {@link ParsedSql} class. + * @param originalSql the SQL statement that is being (or is to be) parsed + */ + ParsedSql(String originalSql) { + this.originalSql = originalSql; + } + + /** + * Return the SQL statement that is being parsed. + */ + String getOriginalSql() { + return this.originalSql; + } + + + /** + * Add a named parameter parsed from this SQL statement. + * @param parameterName the name of the parameter + * @param startIndex the start index in the original SQL String + * @param endIndex the end index in the original SQL String + */ + void addNamedParameter(String parameterName, int startIndex, int endIndex) { + this.parameterNames.add(parameterName); + this.parameterIndexes.add(new int[] {startIndex, endIndex}); + } + + /** + * Return all of the parameters (bind variables) in the parsed SQL statement. + * Repeated occurrences of the same parameter name are included here. + */ + List getParameterNames() { + return this.parameterNames; + } + + /** + * Return the parameter indexes for the specified parameter. + * @param parameterPosition the position of the parameter + * (as index in the parameter names List) + * @return the start index and end index, combined into + * a int array of length 2 + */ + int[] getParameterIndexes(int parameterPosition) { + return this.parameterIndexes.get(parameterPosition); + } + + /** + * Set the count of named parameters in the SQL statement. + * Each parameter name counts once; repeated occurrences do not count here. + */ + void setNamedParameterCount(int namedParameterCount) { + this.namedParameterCount = namedParameterCount; + } + + /** + * Return the count of named parameters in the SQL statement. + * Each parameter name counts once; repeated occurrences do not count here. + */ + int getNamedParameterCount() { + return this.namedParameterCount; + } + + /** + * Set the count of all of the unnamed parameters in the SQL statement. + */ + void setUnnamedParameterCount(int unnamedParameterCount) { + this.unnamedParameterCount = unnamedParameterCount; + } + + /** + * Return the count of all of the unnamed parameters in the SQL statement. + */ + int getUnnamedParameterCount() { + return this.unnamedParameterCount; + } + + /** + * Set the total count of all of the parameters in the SQL statement. + * Repeated occurrences of the same parameter name do count here. + */ + void setTotalParameterCount(int totalParameterCount) { + this.totalParameterCount = totalParameterCount; + } + + /** + * Return the total count of all of the parameters in the SQL statement. + * Repeated occurrences of the same parameter name do count here. + */ + int getTotalParameterCount() { + return this.totalParameterCount; + } + + + /** + * Exposes the original SQL String. + */ + @Override + public String toString() { + return this.originalSql; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/PreparedOperation.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/PreparedOperation.java new file mode 100644 index 000000000000..83e0ac0aa087 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/PreparedOperation.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.Supplier; + +import org.springframework.r2dbc.core.binding.BindTarget; + +/** + * Extension to {@link QueryOperation} for a prepared SQL query + * {@link Supplier} with bound parameters. Contains parameter + * bindings that can be {@link #bindTo bound} bound to a {@link BindTarget}. + *

Can be executed with {@link org.springframework.r2dbc.core.DatabaseClient}. + * + * @author Mark Paluch + * @since 5.3 + * @param underlying operation source. + * @see org.springframework.r2dbc.core.DatabaseClient#sql(Supplier) + */ +public interface PreparedOperation extends QueryOperation { + + /** + * Return the underlying query source. + * @return the query source, such as a statement/criteria object. + */ + T getSource(); + + /** + * Apply bindings to {@link BindTarget}. + * @param target the target to apply bindings to. + */ + void bindTo(BindTarget target); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/QueryOperation.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/QueryOperation.java new file mode 100644 index 000000000000..a9ee62a28064 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/QueryOperation.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.function.Supplier; + +/** + * Interface declaring a query operation that can be represented + * with a query string. This interface is typically implemented + * by classes representing a SQL operation such as {@code SELECT}, + * {@code INSERT}, and such. + * + * @author Mark Paluch + * @since 5.3 + * @see PreparedOperation + */ +@FunctionalInterface +public interface QueryOperation extends Supplier { + + /** + * Returns the string-representation of this operation to + * be used with {@link io.r2dbc.spi.Statement} creation. + * @return the operation as SQL string + * @see io.r2dbc.spi.Connection#createStatement(String) + */ + String toQuery(); + + @Override + default String get() { + return toQuery(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/RowsFetchSpec.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/RowsFetchSpec.java new file mode 100644 index 000000000000..4307858a10aa --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/RowsFetchSpec.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Contract for fetching tabular results. + * + * @author Mark Paluch + * @since 5.3 + * @param the row result type + */ +public interface RowsFetchSpec { + + /** + * Get exactly zero or one result. + * + * @return a mono emitting one element. {@link Mono#empty()} if no match found. + * Completes with {@code IncorrectResultSizeDataAccessException} if more than one match found + */ + Mono one(); + + /** + * Get the first or no result. + * @return a mono emitting the first element. {@link Mono#empty()} if no match found + */ + Mono first(); + + /** + * Get all matching elements. + * @return a flux emitting all results + */ + Flux all(); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/SqlProvider.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/SqlProvider.java new file mode 100644 index 000000000000..731142d7f41d --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/SqlProvider.java @@ -0,0 +1,41 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import org.springframework.lang.Nullable; + +/** + * Interface to be implemented by objects that can provide SQL strings. + * + *

Typically implemented by objects that want to expose the SQL they + * use to create their statements, to allow for better contextual + * information in case of exceptions. + * + * @author Mark Paluch + * @since 5.3 + */ +public interface SqlProvider { + + /** + * Return the SQL string for this object, i.e. + * typically the SQL used for creating statements. + * @return the SQL string, or {@code null} + */ + @Nullable + String getSql(); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/StatementFilterFunction.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/StatementFilterFunction.java new file mode 100644 index 000000000000..281c22574adc --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/StatementFilterFunction.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; +import org.reactivestreams.Publisher; + +import org.springframework.util.Assert; + +/** + * Represents a function that filters an {@link ExecuteFunction execute function}. + *

The filter is executed when a {@link org.reactivestreams.Subscriber} subscribes + * to the {@link Publisher} returned by the {@link DatabaseClient}. + *

StatementFilterFunctions are typically used to specify additional details on + * the Statement objects such as {@code fetchSize} or key generation. + * + * @author Mark Paluch + * @since 5.3 + * @see ExecuteFunction + */ +@FunctionalInterface +public interface StatementFilterFunction { + + /** + * Apply this filter to the given {@link Statement} and {@link ExecuteFunction}. + *

The given {@link ExecuteFunction} represents the next entity in the chain, + * to be invoked via {@link ExecuteFunction#execute(Statement)} invoked} in + * order to proceed with the execution, or not invoked to shortcut the chain. + * @param statement the current {@link Statement} + * @param next the next execute function in the chain + * @return the filtered {@link Result}s. + */ + Publisher filter(Statement statement, ExecuteFunction next); + + /** + * Return a composed filter function that first applies this filter, and then + * applies the given {@code "after"} filter. + * @param afterFilter the filter to apply after this filter + * @return the composed filter. + */ + default StatementFilterFunction andThen(StatementFilterFunction afterFilter) { + Assert.notNull(afterFilter, "StatementFilterFunction must not be null"); + return (request, next) -> filter(request, afterRequest -> afterFilter.filter(afterRequest, next)); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/StatementFilterFunctions.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/StatementFilterFunctions.java new file mode 100644 index 000000000000..c79c16a2c809 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/StatementFilterFunctions.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; +import org.reactivestreams.Publisher; + +/** + * Collection of default {@link StatementFilterFunction}s. + * + * @author Mark Paluch + * @since 5.3 + */ +enum StatementFilterFunctions implements StatementFilterFunction { + + EMPTY_FILTER; + + + @Override + public Publisher filter(Statement statement, ExecuteFunction next) { + return next.execute(statement); + } + + /** + * Return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}. + * @return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}. + */ + public static StatementFilterFunction empty() { + return EMPTY_FILTER; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/UpdatedRowsFetchSpec.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/UpdatedRowsFetchSpec.java new file mode 100644 index 000000000000..b9b20fbb2c31 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/UpdatedRowsFetchSpec.java @@ -0,0 +1,35 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import reactor.core.publisher.Mono; + +/** + * Contract for fetching the number of affected rows. + * + * @author Mark Paluch + * @since 5.3 + */ +public interface UpdatedRowsFetchSpec { + + /** + * Get the number of updated rows. + * @return a mono emitting the number of updated rows + */ + Mono rowsUpdated(); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/AnonymousBindMarkers.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/AnonymousBindMarkers.java new file mode 100644 index 000000000000..60cae994f324 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/AnonymousBindMarkers.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Anonymous, index-based bind marker using a static placeholder. + * Instances are bound by the ordinal position ordered by the appearance of + * the placeholder. This implementation creates indexed bind markers using + * an anonymous placeholder that correlates with an index. + * + *

Note: Anonymous bind markers are problematic because the have to appear + * in generated SQL in the same order they get generated. This might cause + * challenges in the future with complex generate statements. For example those + * containing subselects which limit the freedom of arranging bind markers. + * + * @author Mark Paluch + * @since 5.3 + */ +class AnonymousBindMarkers implements BindMarkers { + + private static final AtomicIntegerFieldUpdater COUNTER_INCREMENTER = AtomicIntegerFieldUpdater + .newUpdater(AnonymousBindMarkers.class, "counter"); + + + private final String placeholder; + + // access via COUNTER_INCREMENTER + @SuppressWarnings("unused") + private volatile int counter = 0; + + + /** + * Create a new {@link AnonymousBindMarkers} instance given {@code placeholder}. + * @param placeholder parameter bind marker + */ + AnonymousBindMarkers(String placeholder) { + this.placeholder = placeholder; + } + + + @Override + public BindMarker next() { + int index = COUNTER_INCREMENTER.getAndIncrement(this); + return new IndexedBindMarkers.IndexedBindMarker(this.placeholder, index); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarker.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarker.java new file mode 100644 index 000000000000..71687186a81b --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarker.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import io.r2dbc.spi.Statement; + +/** + * A bind marker represents a single bindable parameter within a query. + * Bind markers are dialect-specific and provide a + * {@link #getPlaceholder() placeholder} that is used in the actual query. + * + * @author Mark Paluch + * @since 5.3 + * @see Statement#bind + * @see BindMarkers + * @see BindMarkersFactory + */ +public interface BindMarker { + + /** + * Returns the database-specific placeholder for a given substitution. + */ + String getPlaceholder(); + + /** + * Bind the given {@code value} to the {@link Statement} using the underlying binding strategy. + * + * @param bindTarget the target to bind the value to + * @param value the actual value. Must not be {@code null} + * Use {@link #bindNull(BindTarget, Class)} for {@code null} values + * @see Statement#bind + */ + void bind(BindTarget bindTarget, Object value); + + /** + * Bind a {@code null} value to the {@link Statement} using the underlying binding strategy. + * @param bindTarget the target to bind the value to + * @param valueType value type, must not be {@code null} + * @see Statement#bindNull + */ + void bindNull(BindTarget bindTarget, Class valueType); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkers.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkers.java new file mode 100644 index 000000000000..13e2537198bc --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkers.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +/** + * Bind markers represent placeholders in SQL queries for substitution + * for an actual parameter. Using bind markers allows creating safe queries + * so query strings are not required to contain escaped values but rather + * the driver encodes parameter in the appropriate representation. + * + *

{@link BindMarkers} is stateful and can be only used for a single binding + * pass of one or more parameters. It maintains bind indexes/bind parameter names. + * + * @author Mark Paluch + * @since 5.3 + * @see BindMarker + * @see BindMarkersFactory + * @see io.r2dbc.spi.Statement#bind + */ +@FunctionalInterface +public interface BindMarkers { + + /** + * Create a new {@link BindMarker}. + * @return a new {@link BindMarker} + */ + BindMarker next(); + + /** + * Create a new {@link BindMarker} that accepts a {@code hint}. + * Implementations are allowed to consider/ignore/filter + * the name hint to create more expressive bind markers. + * @param hint an optional name hint that can be used as part of the bind marker + * @return a new {@link BindMarker} + */ + default BindMarker next(String hint) { + return next(); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkersFactory.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkersFactory.java new file mode 100644 index 000000000000..3281dd00a534 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkersFactory.java @@ -0,0 +1,149 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.function.Function; + +import org.springframework.util.Assert; + +/** + * This class creates new {@link BindMarkers} instances to bind + * parameter to a specific {@link io.r2dbc.spi.Statement}. + * + *

Bind markers can be typically represented as placeholder and identifier. + * Placeholders are used within the query to execute so the underlying database + * system can substitute the placeholder with the actual value. Identifiers + * are used in R2DBC drivers to bind a value to a bind marker. Identifiers are + * typically a part of an entire bind marker when using indexed or named bind markers. + * + * @author Mark Paluch + * @since 5.3 + * @see BindMarkers + * @see io.r2dbc.spi.Statement + */ +@FunctionalInterface +public interface BindMarkersFactory { + + /** + * Create a new {@link BindMarkers} instance. + * @return a new {@link BindMarkers} instance + */ + BindMarkers create(); + + /** + * Return whether the {@link BindMarkersFactory} uses identifiable + * placeholders. + * @return whether the {@link BindMarkersFactory} uses identifiable + * placeholders. {@code false} if multiple placeholders cannot be + * distinguished by just the {@link BindMarker#getPlaceholder() placeholder} + * identifier. + */ + default boolean identifiablePlaceholders() { + return true; + } + + + // Static, factory methods + + /** + * Create index-based {@link BindMarkers} using indexes to bind parameters. + * Allows customization of the bind marker placeholder {@code prefix} to + * represent the bind marker as placeholder within the query. + * @param prefix bind parameter prefix that is included in + * {@link BindMarker#getPlaceholder()} but not the actual identifier + * @param beginWith the first index to use + * @return a {@link BindMarkersFactory} using {@code prefix} and {@code beginWith} + * @see io.r2dbc.spi.Statement#bindNull(int, Class) + * @see io.r2dbc.spi.Statement#bind(int, Object) + */ + static BindMarkersFactory indexed(String prefix, int beginWith) { + Assert.notNull(prefix, "Prefix must not be null"); + return () -> new IndexedBindMarkers(prefix, beginWith); + } + + /** + * Create anonymous, index-based bind marker using a static placeholder. + * Instances are bound by the ordinal position ordered by the appearance + * of the placeholder. This implementation creates indexed bind markers + * using an anonymous placeholder that correlates with an index. + * @param placeholder parameter placeholder + * @return a {@link BindMarkersFactory} using {@code placeholder} + * @see io.r2dbc.spi.Statement#bindNull(int, Class) + * @see io.r2dbc.spi.Statement#bind(int, Object) + */ + static BindMarkersFactory anonymous(String placeholder) { + Assert.hasText(placeholder, "Placeholder must not be empty!"); + return new BindMarkersFactory() { + + @Override + public BindMarkers create() { + return new AnonymousBindMarkers(placeholder); + } + + @Override + public boolean identifiablePlaceholders() { + return false; + } + }; + } + + /** + * Create named {@link BindMarkers} using identifiers to bind parameters. + * Named bind markers can support {@link BindMarkers#next(String) name hints}. + * If no {@link BindMarkers#next(String) hint} is given, named bind markers can + * use a counter or a random value source to generate unique bind markers. + * Allows customization of the bind marker placeholder {@code prefix} and + * {@code namePrefix} to represent the bind marker as placeholder within + * the query. + * @param prefix bind parameter prefix that is included in + * {@link BindMarker#getPlaceholder()} but not the actual identifier + * @param namePrefix prefix for bind marker name that is included in + * {@link BindMarker#getPlaceholder()} and the actual identifier + * @param maxLength maximal length of parameter names when using name hints + * @return a {@link BindMarkersFactory} using {@code prefix} and {@code beginWith} + * @see io.r2dbc.spi.Statement#bindNull(String, Class) + * @see io.r2dbc.spi.Statement#bind(String, Object) + */ + static BindMarkersFactory named(String prefix, String namePrefix, int maxLength) { + return named(prefix, namePrefix, maxLength, Function.identity()); + } + + /** + * Create named {@link BindMarkers} using identifiers to bind parameters. + * Named bind markers support {@link BindMarkers#next(String) name hints}. + * If no {@link BindMarkers#next(String) hint} is given, named bind markers + * can use a counter or a random value source to generate unique bind markers. + * @param prefix bind parameter prefix that is included in + * {@link BindMarker#getPlaceholder()} but not the actual identifier + * @param namePrefix prefix for bind marker name that is included in + * {@link BindMarker#getPlaceholder()} and the actual identifier + * @param maxLength maximal length of parameter names when using name hints + * @param hintFilterFunction filter {@link Function} to consider + * database-specific limitations in bind marker/variable names such as ASCII chars only + * @return a {@link BindMarkersFactory} using {@code prefix} and {@code beginWith} + * @see io.r2dbc.spi.Statement#bindNull(String, Class) + * @see io.r2dbc.spi.Statement#bind(String, Object) + */ + static BindMarkersFactory named(String prefix, String namePrefix, int maxLength, + Function hintFilterFunction) { + Assert.notNull(prefix, "Prefix must not be null"); + Assert.notNull(namePrefix, "Index prefix must not be null"); + Assert.notNull(hintFilterFunction, "Hint filter function must not be null"); + return () -> new NamedBindMarkers(prefix, namePrefix, maxLength, hintFilterFunction); + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkersFactoryResolver.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkersFactoryResolver.java new file mode 100644 index 000000000000..18ada18f34ee --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindMarkersFactoryResolver.java @@ -0,0 +1,181 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; + +import org.springframework.core.io.support.SpringFactoriesLoader; +import org.springframework.dao.NonTransientDataAccessException; +import org.springframework.lang.Nullable; +import org.springframework.util.LinkedCaseInsensitiveMap; + +/** + * Resolves a {@link BindMarkersFactory} from a {@link ConnectionFactory} using + * {@link BindMarkerFactoryProvider}. Dialect resolution uses Spring's + * {@link SpringFactoriesLoader spring.factories} to determine available extensions. + * + * @author Mark Paluch + * @since 5.3 + * @see BindMarkersFactory + * @see SpringFactoriesLoader + */ +public final class BindMarkersFactoryResolver { + + private static final List DETECTORS = SpringFactoriesLoader.loadFactories( + BindMarkerFactoryProvider.class, BindMarkersFactoryResolver.class.getClassLoader()); + + + /** + * Retrieve a {@link BindMarkersFactory} by inspecting {@link ConnectionFactory} and + * its metadata. + * + * @param connectionFactory the connection factory to inspect + * @return the resolved {@link BindMarkersFactory} + * @throws NoBindMarkersFactoryException if no {@link BindMarkersFactory} can be + * resolved + */ + public static BindMarkersFactory resolve(ConnectionFactory connectionFactory) { + + for (BindMarkerFactoryProvider detector : DETECTORS) { + BindMarkersFactory bindMarkersFactory = detector.getBindMarkers( + connectionFactory); + if (bindMarkersFactory != null) { + return bindMarkersFactory; + } + } + + throw new NoBindMarkersFactoryException( + String.format("Cannot determine a BindMarkersFactory for %s using %s", + connectionFactory.getMetadata().getName(), connectionFactory)); + } + + + // utility constructor. + private BindMarkersFactoryResolver() { + } + + + /** + * SPI to extend Spring's default R2DBC BindMarkersFactory discovery mechanism. + * Implementations of this interface are discovered through Spring's + * {@link SpringFactoriesLoader} mechanism. + * @see SpringFactoriesLoader + */ + @FunctionalInterface + public interface BindMarkerFactoryProvider { + + /** + * Returns a {@link BindMarkersFactory} for a {@link ConnectionFactory}. + * + * @param connectionFactory the connection factory to be used with the + * {@link BindMarkersFactory}. + * @return the {@link BindMarkersFactory} if the {@link BindMarkerFactoryProvider} + * can provide a bind marker factory object, otherwise {@code null} + */ + @Nullable + BindMarkersFactory getBindMarkers(ConnectionFactory connectionFactory); + + } + + + /** + * Exception thrown when {@link BindMarkersFactoryResolver} cannot resolve a + * {@link BindMarkersFactory}. + */ + @SuppressWarnings("serial") + public static class NoBindMarkersFactoryException + extends NonTransientDataAccessException { + + /** + * Constructor for NoBindMarkersFactoryException. + * + * @param msg the detail message + */ + public NoBindMarkersFactoryException(String msg) { + super(msg); + } + + } + + + /** + * Built-in bind maker factories. Used typically as last {@link BindMarkerFactoryProvider} + * when other providers register with a higher precedence. + * @see org.springframework.core.Ordered + * @see org.springframework.core.annotation.AnnotationAwareOrderComparator + */ + static class BuiltInBindMarkersFactoryProvider implements BindMarkerFactoryProvider { + + private static final Map BUILTIN = new LinkedCaseInsensitiveMap<>( + Locale.ENGLISH); + + static { + BUILTIN.put("H2", BindMarkersFactory.indexed("$", 1)); + BUILTIN.put("Microsoft SQL Server", BindMarkersFactory.named("@", "P", 32, + BuiltInBindMarkersFactoryProvider::filterBindMarker)); + BUILTIN.put("MySQL", BindMarkersFactory.anonymous("?")); + BUILTIN.put("MariaDB", BindMarkersFactory.anonymous("?")); + BUILTIN.put("PostgreSQL", BindMarkersFactory.indexed("$", 1)); + } + + + @Override + public BindMarkersFactory getBindMarkers(ConnectionFactory connectionFactory) { + ConnectionFactoryMetadata metadata = connectionFactory.getMetadata(); + BindMarkersFactory r2dbcDialect = BUILTIN.get(metadata.getName()); + + if (r2dbcDialect != null) { + return r2dbcDialect; + } + + + for (String it : BUILTIN.keySet()) { + if (metadata.getName().contains(it)) { + return BUILTIN.get(it); + } + } + + return null; + } + + private static String filterBindMarker(CharSequence input) { + StringBuilder builder = new StringBuilder(); + + for (int i = 0; i < input.length(); i++) { + + char ch = input.charAt(i); + // ascii letter or digit + if (Character.isLetterOrDigit(ch) && ch < 127) { + builder.append(ch); + } + } + + if (builder.length() == 0) { + return ""; + } + + return "_" + builder.toString(); + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindTarget.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindTarget.java new file mode 100644 index 000000000000..ac93390f97d9 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/BindTarget.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +/** + * Target to apply bindings to. + * + * @author Mark Paluch + * @since 5.3 + * @see io.r2dbc.spi.Statement#bind + * @see io.r2dbc.spi.Statement#bindNull + */ +public interface BindTarget { + + /** + * Bind a value. + * @param identifier the identifier to bind to + * @param value the value to bind + */ + void bind(String identifier, Object value); + + /** + * Bind a value to an index. Indexes are zero-based. + * @param index the index to bind to + * @param value the value to bind + */ + void bind(int index, Object value); + + /** + * Bind a {@code null} value. + * @param identifier the identifier to bind to + * @param type the type of {@code null} value + */ + void bindNull(String identifier, Class type); + + /** + * Bind a {@code null} value. + * @param index the index to bind to + * @param type the type of {@code null} value + */ + void bindNull(int index, Class type); + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/Bindings.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/Bindings.java new file mode 100644 index 000000000000..b430d96c0f63 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/Bindings.java @@ -0,0 +1,262 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Spliterator; +import java.util.function.Consumer; + +import io.r2dbc.spi.Statement; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Value object representing value and {@code null} bindings + * for a {@link Statement} using {@link BindMarkers}. + * Bindings are typically immutable. + * + * @author Mark Paluch + * @since 5.3 + */ +public class Bindings implements Iterable { + + private static final Bindings EMPTY = new Bindings(); + + private final Map bindings; + + + /** + * Create empty {@link Bindings}. + */ + public Bindings() { + this.bindings = Collections.emptyMap(); + } + + /** + * Create {@link Bindings} from a {@link Map}. + * @param bindings must not be {@code null} + */ + public Bindings(Collection bindings) { + Assert.notNull(bindings, "Bindings must not be null"); + Map mapping = new LinkedHashMap<>(bindings.size()); + bindings.forEach(binding -> mapping.put(binding.getBindMarker(), binding)); + this.bindings = mapping; + } + + Bindings(Map bindings) { + this.bindings = bindings; + } + + + /** + * Create a new, empty {@link Bindings} object. + * + * @return a new, empty {@link Bindings} object. + */ + public static Bindings empty() { + return EMPTY; + } + + + protected Map getBindings() { + return this.bindings; + } + + /** + * Merge this bindings with an other {@link Bindings} object and create a new merged + * {@link Bindings} object. + * @param left the left object to merge with + * @param right the right object to merge with + * @return a new, merged {@link Bindings} object + */ + public static Bindings merge(Bindings left, Bindings right) { + Assert.notNull(left, "Left side Bindings must not be null"); + Assert.notNull(right, "Right side Bindings must not be null"); + List result = new ArrayList<>( + left.getBindings().size() + right.getBindings().size()); + result.addAll(left.getBindings().values()); + result.addAll(right.getBindings().values()); + return new Bindings(result); + } + + /** + * Merge this bindings with an other {@link Bindings} object and create a new merged + * {@link Bindings} object. + * @param other the object to merge with + * @return a new, merged {@link Bindings} object + */ + public Bindings and(Bindings other) { + return merge(this, other); + } + + /** + * Apply the bindings to a {@link BindTarget}. + * @param bindTarget the target to apply bindings to + */ + public void apply(BindTarget bindTarget) { + Assert.notNull(bindTarget, "BindTarget must not be null"); + this.bindings.forEach((marker, binding) -> binding.apply(bindTarget)); + } + + /** + * Perform the given action for each binding of this {@link Bindings} until all + * bindings have been processed or the action throws an exception. Actions are + * performed in the order of iteration (if an iteration order is specified). + * Exceptions thrown by the action are relayed to the + * @param action the action to be performed for each {@link Binding} + */ + public void forEach(Consumer action) { + this.bindings.forEach((marker, binding) -> action.accept(binding)); + } + + @Override + public Iterator iterator() { + return this.bindings.values().iterator(); + } + + @Override + public Spliterator spliterator() { + return this.bindings.values().spliterator(); + } + + + /** + * Base class for value objects representing a value or a {@code NULL} binding. + */ + public abstract static class Binding { + + private final BindMarker marker; + + protected Binding(BindMarker marker) { + this.marker = marker; + } + + /** + * Return the associated {@link BindMarker}. + * @return the associated {@link BindMarker}. + */ + public BindMarker getBindMarker() { + return this.marker; + } + + /** + * Return whether the binding has a value associated with it. + * @return {@code true} if there is a value present, otherwise {@code false} + * for a {@code NULL} binding. + */ + public abstract boolean hasValue(); + + /** + * Return whether the binding is empty. + * @return {@code true} if this is is a {@code NULL} binding + */ + public boolean isNull() { + return !hasValue(); + } + + /** + * Return the binding value. + * @return value of this binding. Can be {@code null} + * if this is a {@code NULL} binding. + */ + @Nullable + public abstract Object getValue(); + + /** + * Apply the binding to a {@link BindTarget}. + * @param bindTarget the target to apply bindings to + */ + public abstract void apply(BindTarget bindTarget); + + } + + + /** + * Value binding. + */ + static class ValueBinding extends Binding { + + private final Object value; + + + ValueBinding(BindMarker marker, Object value) { + super(marker); + this.value = value; + } + + + @Override + public boolean hasValue() { + return true; + } + + @Override + public Object getValue() { + return this.value; + } + + @Override + public void apply(BindTarget bindTarget) { + getBindMarker().bind(bindTarget, getValue()); + } + + } + + /** + * {@code NULL} binding. + */ + static class NullBinding extends Binding { + + private final Class valueType; + + + NullBinding(BindMarker marker, Class valueType) { + super(marker); + this.valueType = valueType; + } + + + @Override + public boolean hasValue() { + return false; + } + + @Override + @Nullable + public Object getValue() { + return null; + } + + public Class getValueType() { + return this.valueType; + } + + @Override + public void apply(BindTarget bindTarget) { + getBindMarker().bindNull(bindTarget, getValueType()); + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/IndexedBindMarkers.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/IndexedBindMarkers.java new file mode 100644 index 000000000000..c3cfc6ffcfac --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/IndexedBindMarkers.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Index-based bind marker. This implementation creates indexed bind + * markers using a numeric index and an optional prefix for bind markers + * to be represented within the query string. + * @author Mark Paluch + * @author Jens Schauder + * @since 5.3 + */ +class IndexedBindMarkers implements BindMarkers { + + private static final AtomicIntegerFieldUpdater COUNTER_INCREMENTER = AtomicIntegerFieldUpdater + .newUpdater(IndexedBindMarkers.class, "counter"); + + + private final int offset; + + private final String prefix; + + // access via COUNTER_INCREMENTER + @SuppressWarnings("unused") + private volatile int counter; + + + /** + * Create a new {@link IndexedBindMarker} instance given {@code prefix} and {@code beginWith}. + * @param prefix bind parameter prefix + * @param beginWith the first index to use + */ + IndexedBindMarkers(String prefix, int beginWith) { + this.counter = 0; + this.prefix = prefix; + this.offset = beginWith; + } + + + @Override + public BindMarker next() { + int index = COUNTER_INCREMENTER.getAndIncrement(this); + return new IndexedBindMarker(this.prefix + "" + (index + this.offset), index); + } + + /** + * A single indexed bind marker. + * @author Mark Paluch + */ + static class IndexedBindMarker implements BindMarker { + + private final String placeholder; + + private final int index; + + + IndexedBindMarker(String placeholder, int index) { + this.placeholder = placeholder; + this.index = index; + } + + + @Override + public String getPlaceholder() { + return this.placeholder; + } + + @Override + public void bind(BindTarget target, Object value) { + target.bind(this.index, value); + } + + @Override + public void bindNull(BindTarget target, Class valueType) { + target.bindNull(this.index, valueType); + } + + public int getIndex() { + return this.index; + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/MutableBindings.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/MutableBindings.java new file mode 100644 index 000000000000..82d8ac92e7c9 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/MutableBindings.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.LinkedHashMap; + +import io.r2dbc.spi.Statement; + +import org.springframework.util.Assert; + +/** + * Mutable extension to {@link Bindings} for Value and {@code null} bindings + * for a {@link Statement} using {@link BindMarkers}. + * + * @author Mark Paluch + * @since 5.3 + */ +public class MutableBindings extends Bindings { + + private final BindMarkers markers; + + + /** + * Create new {@link MutableBindings}. + * @param markers must not be {@code null}. + */ + public MutableBindings(BindMarkers markers) { + super(new LinkedHashMap<>()); + Assert.notNull(markers, "BindMarkers must not be null"); + this.markers = markers; + } + + + /** + * Obtain the next {@link BindMarker}. + * Increments {@link BindMarkers} state + * @return the next {@link BindMarker} + */ + public BindMarker nextMarker() { + return this.markers.next(); + } + + /** + * Obtain the next {@link BindMarker} with a name {@code hint}. + * Increments {@link BindMarkers} state. + * @param hint name hint + * @return the next {@link BindMarker} + */ + public BindMarker nextMarker(String hint) { + return this.markers.next(hint); + } + + /** + * Bind a value to {@link BindMarker}. + * @param marker must not be {@code null} + * @param value must not be {@code null} + */ + public MutableBindings bind(BindMarker marker, Object value) { + Assert.notNull(marker, "BindMarker must not be null"); + Assert.notNull(value, "Value must not be null"); + getBindings().put(marker, new ValueBinding(marker, value)); + return this; + } + + /** + * Bind a value and return the related {@link BindMarker}. + * Increments {@link BindMarkers} state. + * @param value must not be {@code null} + */ + public BindMarker bind(Object value) { + Assert.notNull(value, "Value must not be null"); + BindMarker marker = nextMarker(); + getBindings().put(marker, new ValueBinding(marker, value)); + return marker; + } + + /** + * Bind a {@code NULL} value to {@link BindMarker}. + * @param marker must not be {@code null} + * @param valueType must not be {@code null} + */ + public MutableBindings bindNull(BindMarker marker, Class valueType) { + Assert.notNull(marker, "BindMarker must not be null"); + Assert.notNull(valueType, "Value type must not be null"); + getBindings().put(marker, new NullBinding(marker, valueType)); + return this; + } + + /** + * Bind a {@code NULL} value and return the related {@link BindMarker}. + * Increments {@link BindMarkers} state. + * @param valueType must not be {@code null} + */ + public BindMarker bindNull(Class valueType) { + Assert.notNull(valueType, "Value type must not be null"); + BindMarker marker = nextMarker(); + getBindings().put(marker, new NullBinding(marker, valueType)); + return marker; + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/NamedBindMarkers.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/NamedBindMarkers.java new file mode 100644 index 000000000000..fab7caf028fa --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/NamedBindMarkers.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.Function; + +import org.springframework.util.Assert; + +/** + * Name-based bind markers. + * + * @author Mark Paluch + * @since 5.3 + */ +class NamedBindMarkers implements BindMarkers { + + private static final AtomicIntegerFieldUpdater COUNTER_INCREMENTER = AtomicIntegerFieldUpdater + .newUpdater(NamedBindMarkers.class, "counter"); + + + private final String prefix; + + private final String namePrefix; + + private final int nameLimit; + + private final Function hintFilterFunction; + + // access via COUNTER_INCREMENTER + @SuppressWarnings("unused") + private volatile int counter; + + + NamedBindMarkers(String prefix, String namePrefix, int nameLimit, Function hintFilterFunction) { + this.prefix = prefix; + this.namePrefix = namePrefix; + this.nameLimit = nameLimit; + this.hintFilterFunction = hintFilterFunction; + } + + + @Override + public BindMarker next() { + String name = nextName(); + return new NamedBindMarker(this.prefix + name, name); + } + + @Override + public BindMarker next(String hint) { + Assert.notNull(hint, "Name hint must not be null"); + String name = nextName() + this.hintFilterFunction.apply(hint); + + if (name.length() > this.nameLimit) { + name = name.substring(0, this.nameLimit); + } + + return new NamedBindMarker(this.prefix + name, name); + } + + private String nextName() { + int index = COUNTER_INCREMENTER.getAndIncrement(this); + return this.namePrefix + index; + } + + + /** + * A single named bind marker. + */ + static class NamedBindMarker implements BindMarker { + + private final String placeholder; + + private final String identifier; + + NamedBindMarker(String placeholder, String identifier) { + + this.placeholder = placeholder; + this.identifier = identifier; + } + + @Override + public String getPlaceholder() { + return this.placeholder; + } + + @Override + public void bind(BindTarget target, Object value) { + target.bind(this.identifier, value); + } + + @Override + public void bindNull(BindTarget target, Class valueType) { + target.bindNull(this.identifier, valueType); + } + + } + +} diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/package-info.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/package-info.java new file mode 100644 index 000000000000..e95acdc21d77 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/binding/package-info.java @@ -0,0 +1,9 @@ +/** + * Classes providing an abstraction over SQL bind markers. + */ +@NonNullApi +@NonNullFields +package org.springframework.r2dbc.core.binding; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/package-info.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/package-info.java new file mode 100644 index 000000000000..19d338ecf1f4 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/core/package-info.java @@ -0,0 +1,6 @@ +/** + * Core domain types around DatabaseClient. + */ +@org.springframework.lang.NonNullApi +@org.springframework.lang.NonNullFields +package org.springframework.r2dbc.core; diff --git a/spring-r2dbc/src/main/java/org/springframework/r2dbc/package-info.java b/spring-r2dbc/src/main/java/org/springframework/r2dbc/package-info.java new file mode 100644 index 000000000000..c892eec7f380 --- /dev/null +++ b/spring-r2dbc/src/main/java/org/springframework/r2dbc/package-info.java @@ -0,0 +1,21 @@ +/** + * The classes in this package make R2DBC easier to use and + * reduce the likelihood of common errors. In particular, they: + *

    + *
  • Simplify error handling, avoiding the need for resource management + * blocks in application code. + *
  • Present exceptions to application code in a generic hierarchy of + * unchecked exceptions, enabling applications to catch data access + * exceptions without being dependent on R2DBC, and to ignore fatal + * exceptions there is no value in catching. + *
  • Allow the implementation of error handling to be modified + * to target different RDBMSes without introducing proprietary + * dependencies into application code. + *
+ */ +@NonNullApi +@NonNullFields +package org.springframework.r2dbc; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/DatabaseClientExtensions.kt b/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/DatabaseClientExtensions.kt new file mode 100644 index 000000000000..0528f2c76f95 --- /dev/null +++ b/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/DatabaseClientExtensions.kt @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core + +import kotlinx.coroutines.reactive.awaitFirstOrNull + +/** + * Coroutines variant of [DatabaseClient.GenericExecuteSpec.then]. + * + * @author Sebastien Deleuze + */ +suspend fun DatabaseClient.GenericExecuteSpec.await() { + then().awaitFirstOrNull() +} + +/** + * Extension for [DatabaseClient.BindSpec.bind] providing a variant leveraging reified type parameters + * + * @author Mark Paluch + * @author Ibanga Enoobong Ime + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER") +inline fun DatabaseClient.GenericExecuteSpec.bind(index: Int, value: T?) = bind(index, Parameter.fromOrEmpty(value, T::class.java)) + +/** + * Extension for [DatabaseClient.BindSpec.bind] providing a variant leveraging reified type parameters + * + * @author Mark Paluch + * @author Ibanga Enoobong Ime + */ +@Suppress("EXTENSION_SHADOWED_BY_MEMBER") +inline fun DatabaseClient.GenericExecuteSpec.bind(name: String, value: T?) = bind(name, Parameter.fromOrEmpty(value, T::class.java)) diff --git a/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/RowsFetchSpecExtensions.kt b/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/RowsFetchSpecExtensions.kt new file mode 100644 index 000000000000..d79d954c0706 --- /dev/null +++ b/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/RowsFetchSpecExtensions.kt @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.r2dbc.core + +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.reactive.asFlow +import kotlinx.coroutines.reactive.awaitFirstOrNull +import org.springframework.dao.EmptyResultDataAccessException + +/** + * Non-nullable Coroutines variant of [RowsFetchSpec.one]. + * + * @author Sebastien Deleuze + */ +suspend fun RowsFetchSpec.awaitOne(): T { + return one().awaitFirstOrNull() ?: throw EmptyResultDataAccessException(1) +} + +/** + * Nullable Coroutines variant of [RowsFetchSpec.one]. + * + * @author Sebastien Deleuze + */ +suspend fun RowsFetchSpec.awaitOneOrNull(): T? = + one().awaitFirstOrNull() + +/** + * Non-nullable Coroutines variant of [RowsFetchSpec.first]. + * + * @author Sebastien Deleuze + */ +suspend fun RowsFetchSpec.awaitFirst(): T { + return first().awaitFirstOrNull() ?: throw EmptyResultDataAccessException(1) +} + +/** + * Nullable Coroutines variant of [RowsFetchSpec.first]. + * + * @author Sebastien Deleuze + */ +suspend fun RowsFetchSpec.awaitFirstOrNull(): T? = + first().awaitFirstOrNull() + +/** + * Coroutines [Flow] variant of [RowsFetchSpec.all]. + * + * @author Sebastien Deleuze + */ +fun RowsFetchSpec.flow(): Flow = all().asFlow() diff --git a/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/UpdatedRowsFetchSpecExtensions.kt b/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/UpdatedRowsFetchSpecExtensions.kt new file mode 100644 index 000000000000..576a5e1a0900 --- /dev/null +++ b/spring-r2dbc/src/main/kotlin/org/springframework/r2dbc/core/UpdatedRowsFetchSpecExtensions.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core + +import kotlinx.coroutines.reactive.awaitSingle + +/** + * Coroutines variant of [UpdatedRowsFetchSpec.rowsUpdated]. + * + * @author Fred Montariol + */ +suspend fun UpdatedRowsFetchSpec.awaitRowsUpdated(): Int = + rowsUpdated().awaitSingle() diff --git a/spring-r2dbc/src/main/resources/META-INF/spring.factories b/spring-r2dbc/src/main/resources/META-INF/spring.factories new file mode 100644 index 000000000000..c999996af4cb --- /dev/null +++ b/spring-r2dbc/src/main/resources/META-INF/spring.factories @@ -0,0 +1 @@ +org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver$BindMarkerFactoryProvider=org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver.BuiltInBindMarkersFactoryProvider diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/ConnectionFactoryUtilsUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/ConnectionFactoryUtilsUnitTests.java new file mode 100644 index 000000000000..fd976f38f3a8 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/ConnectionFactoryUtilsUnitTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2019-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import io.r2dbc.spi.R2dbcBadGrammarException; +import io.r2dbc.spi.R2dbcDataIntegrityViolationException; +import io.r2dbc.spi.R2dbcException; +import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.r2dbc.spi.R2dbcPermissionDeniedException; +import io.r2dbc.spi.R2dbcRollbackException; +import io.r2dbc.spi.R2dbcTimeoutException; +import io.r2dbc.spi.R2dbcTransientResourceException; +import org.junit.jupiter.api.Test; + +import org.springframework.dao.ConcurrencyFailureException; +import org.springframework.dao.DataAccessResourceFailureException; +import org.springframework.dao.DataIntegrityViolationException; +import org.springframework.dao.PermissionDeniedDataAccessException; +import org.springframework.dao.QueryTimeoutException; +import org.springframework.dao.TransientDataAccessResourceException; +import org.springframework.r2dbc.BadSqlGrammarException; +import org.springframework.r2dbc.UncategorizedR2dbcException; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ConnectionFactoryUtils}. + * + * @author Mark Paluch + */ +public class ConnectionFactoryUtilsUnitTests { + + @Test + public void shouldTranslateTransientResourceException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcTransientResourceException("")); + assertThat(exception).isInstanceOf(TransientDataAccessResourceException.class); + } + + @Test + public void shouldTranslateRollbackException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcRollbackException()); + assertThat(exception).isInstanceOf(ConcurrencyFailureException.class); + } + + @Test + public void shouldTranslateTimeoutException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcTimeoutException()); + assertThat(exception).isInstanceOf(QueryTimeoutException.class); + } + + @Test + public void shouldNotTranslateUnknownExceptions() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new MyTransientExceptions()); + assertThat(exception).isInstanceOf(UncategorizedR2dbcException.class); + } + + @Test + public void shouldTranslateNonTransientResourceException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcNonTransientResourceException()); + assertThat(exception).isInstanceOf(DataAccessResourceFailureException.class); + } + + @Test + public void shouldTranslateIntegrityViolationException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcDataIntegrityViolationException()); + assertThat(exception).isInstanceOf(DataIntegrityViolationException.class); + } + + @Test + public void shouldTranslatePermissionDeniedException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcPermissionDeniedException()); + assertThat(exception).isInstanceOf(PermissionDeniedDataAccessException.class); + } + + @Test + public void shouldTranslateBadSqlGrammarException() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("", "", + new R2dbcBadGrammarException()); + assertThat(exception).isInstanceOf(BadSqlGrammarException.class); + } + + @Test + public void messageGeneration() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("TASK", + "SOME-SQL", new R2dbcTransientResourceException("MESSAGE")); + assertThat(exception).isInstanceOf( + TransientDataAccessResourceException.class).hasMessage( + "TASK; SQL [SOME-SQL]; MESSAGE; nested exception is io.r2dbc.spi.R2dbcTransientResourceException: MESSAGE"); + } + + @Test + public void messageGenerationNullSQL() { + Exception exception = ConnectionFactoryUtils.convertR2dbcException("TASK", null, + new R2dbcTransientResourceException("MESSAGE")); + assertThat(exception).isInstanceOf( + TransientDataAccessResourceException.class).hasMessage( + "TASK; MESSAGE; nested exception is io.r2dbc.spi.R2dbcTransientResourceException: MESSAGE"); + } + + @Test + public void messageGenerationNullMessage() { + + Exception exception = ConnectionFactoryUtils.convertR2dbcException("TASK", + "SOME-SQL", new R2dbcTransientResourceException()); + assertThat(exception).isInstanceOf( + TransientDataAccessResourceException.class).hasMessage( + "TASK; SQL [SOME-SQL]; null; nested exception is io.r2dbc.spi.R2dbcTransientResourceException"); + } + + @SuppressWarnings("serial") + private static class MyTransientExceptions extends R2dbcException { + } +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/DelegatingConnectionFactoryUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/DelegatingConnectionFactoryUnitTests.java new file mode 100644 index 000000000000..f6c7ca0fbdf5 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/DelegatingConnectionFactoryUnitTests.java @@ -0,0 +1,63 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link DelegatingConnectionFactory}. + * + * @author Mark Paluch + */ +public class DelegatingConnectionFactoryUnitTests { + + ConnectionFactory delegate = mock(ConnectionFactory.class); + + Connection connectionMock = mock(Connection.class); + + DelegatingConnectionFactory connectionFactory = new ExampleConnectionFactory( + delegate); + + @Test + public void shouldDelegateGetConnection() { + + Mono connectionMono = Mono.just(connectionMock); + when(delegate.create()).thenReturn((Mono) connectionMono); + + assertThat(connectionFactory.create()).isSameAs(connectionMono); + } + + @Test + public void shouldDelegateUnwrapWithoutImplementing() { + assertThat(connectionFactory.unwrap()).isSameAs(delegate); + } + + static class ExampleConnectionFactory extends DelegatingConnectionFactory { + + ExampleConnectionFactory(ConnectionFactory targetConnectionFactory) { + super(targetConnectionFactory); + } + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java new file mode 100644 index 000000000000..3003b1a354de --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/R2dbcTransactionManagerUnitTests.java @@ -0,0 +1,488 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import java.util.concurrent.atomic.AtomicInteger; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.R2dbcBadGrammarException; +import io.r2dbc.spi.Statement; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.transaction.CannotCreateTransactionException; +import org.springframework.transaction.IllegalTransactionStateException; +import org.springframework.transaction.TransactionDefinition; +import org.springframework.transaction.reactive.TransactionSynchronization; +import org.springframework.transaction.reactive.TransactionSynchronizationManager; +import org.springframework.transaction.reactive.TransactionalOperator; +import org.springframework.transaction.support.DefaultTransactionDefinition; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.never; +import static org.mockito.BDDMockito.reset; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoMoreInteractions; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link R2dbcTransactionManager}. + * + * @author Mark Paluch + */ +public class R2dbcTransactionManagerUnitTests { + + ConnectionFactory connectionFactoryMock = mock(ConnectionFactory.class); + + Connection connectionMock = mock(Connection.class); + + private R2dbcTransactionManager tm; + + @BeforeEach + public void before() { + + when(connectionFactoryMock.create()).thenReturn((Mono) Mono.just(connectionMock)); + when(connectionMock.beginTransaction()).thenReturn(Mono.empty()); + when(connectionMock.close()).thenReturn(Mono.empty()); + tm = new R2dbcTransactionManager(connectionFactoryMock); + } + + @Test + public void testSimpleTransaction() { + TestTransactionSynchronization sync = new TestTransactionSynchronization( + TransactionSynchronization.STATUS_COMMITTED); + AtomicInteger commits = new AtomicInteger(); + when(connectionMock.commitTransaction()).thenReturn( + Mono.fromRunnable(commits::incrementAndGet)); + + TransactionalOperator operator = TransactionalOperator.create(tm); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock) + .flatMap(connection -> TransactionSynchronizationManager.forCurrentTransaction() + .doOnNext(synchronizationManager -> synchronizationManager.registerSynchronization( + sync))) + .as(operator::transactional) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + assertThat(commits).hasValue(1); + verify(connectionMock).isAutoCommit(); + verify(connectionMock).beginTransaction(); + verify(connectionMock).commitTransaction(); + verify(connectionMock).close(); + verifyNoMoreInteractions(connectionMock); + + assertThat(sync.beforeCommitCalled).isTrue(); + assertThat(sync.afterCommitCalled).isTrue(); + assertThat(sync.beforeCompletionCalled).isTrue(); + assertThat(sync.afterCompletionCalled).isTrue(); + } + + @Test + public void testBeginFails() { + reset(connectionFactoryMock); + when(connectionFactoryMock.create()).thenReturn( + Mono.error(new R2dbcBadGrammarException("fail"))); + + when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setIsolationLevel(TransactionDefinition.ISOLATION_SERIALIZABLE); + + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock).as( + operator::transactional) + .as(StepVerifier::create) + .expectErrorSatisfies(actual -> assertThat(actual).isInstanceOf( + CannotCreateTransactionException.class).hasCauseInstanceOf( + R2dbcBadGrammarException.class)) + .verify(); + } + + @Test + public void appliesIsolationLevel() { + when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + when(connectionMock.getTransactionIsolationLevel()).thenReturn( + IsolationLevel.READ_COMMITTED); + when(connectionMock.setTransactionIsolationLevel(any())).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setIsolationLevel(TransactionDefinition.ISOLATION_SERIALIZABLE); + + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock).as( + operator::transactional) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + verify(connectionMock).beginTransaction(); + verify(connectionMock).setTransactionIsolationLevel( + IsolationLevel.READ_COMMITTED); + verify(connectionMock).setTransactionIsolationLevel(IsolationLevel.SERIALIZABLE); + verify(connectionMock).commitTransaction(); + verify(connectionMock).close(); + } + + @Test + public void doesNotSetIsolationLevelIfMatch() { + when(connectionMock.getTransactionIsolationLevel()).thenReturn( + IsolationLevel.READ_COMMITTED); + when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setIsolationLevel(TransactionDefinition.ISOLATION_READ_COMMITTED); + + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock).as( + operator::transactional) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + verify(connectionMock).beginTransaction(); + verify(connectionMock, never()).setTransactionIsolationLevel(any()); + verify(connectionMock).commitTransaction(); + } + + @Test + public void doesNotSetAutoCommitDisabled() { + when(connectionMock.isAutoCommit()).thenReturn(false); + when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock).as( + operator::transactional) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + verify(connectionMock).beginTransaction(); + verify(connectionMock, never()).setAutoCommit(anyBoolean()); + verify(connectionMock).commitTransaction(); + } + + @Test + public void restoresAutoCommit() { + when(connectionMock.isAutoCommit()).thenReturn(true); + when(connectionMock.setAutoCommit(anyBoolean())).thenReturn(Mono.empty()); + when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock).as( + operator::transactional) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + verify(connectionMock).beginTransaction(); + verify(connectionMock).setAutoCommit(false); + verify(connectionMock).setAutoCommit(true); + verify(connectionMock).commitTransaction(); + verify(connectionMock).close(); + } + + @Test + public void appliesReadOnly() { + when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + when(connectionMock.setTransactionIsolationLevel(any())).thenReturn(Mono.empty()); + Statement statement = mock(Statement.class); + when(connectionMock.createStatement(anyString())).thenReturn(statement); + when(statement.execute()).thenReturn(Mono.empty()); + tm.setEnforceReadOnly(true); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setReadOnly(true); + + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock).as( + operator::transactional) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + verify(connectionMock).isAutoCommit(); + verify(connectionMock).beginTransaction(); + verify(connectionMock).createStatement("SET TRANSACTION READ ONLY"); + verify(connectionMock).commitTransaction(); + verify(connectionMock).close(); + verifyNoMoreInteractions(connectionMock); + } + + @Test + public void testCommitFails() { + when(connectionMock.commitTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Commit should fail")))); + + when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty()); + + TransactionalOperator operator = TransactionalOperator.create(tm); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock) + .doOnNext(connection -> connection.createStatement("foo")).then() + .as(operator::transactional) + .as(StepVerifier::create) + .verifyError(IllegalTransactionStateException.class); + + verify(connectionMock).isAutoCommit(); + verify(connectionMock).beginTransaction(); + verify(connectionMock).createStatement("foo"); + verify(connectionMock).commitTransaction(); + verify(connectionMock).close(); + verifyNoMoreInteractions(connectionMock); + } + + @Test + public void testRollback() { + + AtomicInteger commits = new AtomicInteger(); + when(connectionMock.commitTransaction()).thenReturn( + Mono.fromRunnable(commits::incrementAndGet)); + + AtomicInteger rollbacks = new AtomicInteger(); + when(connectionMock.rollbackTransaction()).thenReturn( + Mono.fromRunnable(rollbacks::incrementAndGet)); + + TransactionalOperator operator = TransactionalOperator.create(tm); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock) + .doOnNext(connection -> { + throw new IllegalStateException(); + }).as(operator::transactional) + .as(StepVerifier::create) + .verifyError(IllegalStateException.class); + + assertThat(commits).hasValue(0); + assertThat(rollbacks).hasValue(1); + verify(connectionMock).isAutoCommit(); + verify(connectionMock).beginTransaction(); + verify(connectionMock).rollbackTransaction(); + verify(connectionMock).close(); + verifyNoMoreInteractions(connectionMock); + } + + @Test + public void testRollbackFails() { + when(connectionMock.rollbackTransaction()).thenReturn(Mono.defer(() -> Mono.error(new R2dbcBadGrammarException("Commit should fail"))), Mono.empty()); + + TransactionalOperator operator = TransactionalOperator.create(tm); + + operator.execute(reactiveTransaction -> { + + reactiveTransaction.setRollbackOnly(); + + return ConnectionFactoryUtils.getConnection(connectionFactoryMock) + .doOnNext(connection -> connection.createStatement("foo")).then(); + }).as(StepVerifier::create) + .verifyError(IllegalTransactionStateException.class); + + verify(connectionMock).isAutoCommit(); + verify(connectionMock).beginTransaction(); + verify(connectionMock).createStatement("foo"); + verify(connectionMock, never()).commitTransaction(); + verify(connectionMock).rollbackTransaction(); + verify(connectionMock).close(); + verifyNoMoreInteractions(connectionMock); + } + + @Test + public void testTransactionSetRollbackOnly() { + when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty()); + TestTransactionSynchronization sync = new TestTransactionSynchronization( + TransactionSynchronization.STATUS_ROLLED_BACK); + + TransactionalOperator operator = TransactionalOperator.create(tm); + + operator.execute(tx -> { + + tx.setRollbackOnly(); + assertThat(tx.isNewTransaction()).isTrue(); + + return TransactionSynchronizationManager.forCurrentTransaction().doOnNext( + synchronizationManager -> { + assertThat(synchronizationManager.hasResource(connectionFactoryMock)).isTrue(); + synchronizationManager.registerSynchronization(sync); + }).then(); + }).as(StepVerifier::create) + .verifyComplete(); + + verify(connectionMock).isAutoCommit(); + verify(connectionMock).beginTransaction(); + verify(connectionMock).rollbackTransaction(); + verify(connectionMock).close(); + verifyNoMoreInteractions(connectionMock); + + assertThat(sync.beforeCommitCalled).isFalse(); + assertThat(sync.afterCommitCalled).isFalse(); + assertThat(sync.beforeCompletionCalled).isTrue(); + assertThat(sync.afterCompletionCalled).isTrue(); + } + + @Test + public void testPropagationNeverWithExistingTransaction() { + when(connectionMock.rollbackTransaction()).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW); + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + operator.execute(tx1 -> { + + assertThat(tx1.isNewTransaction()).isTrue(); + + definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_NEVER); + return operator.execute(tx2 -> { + + fail("Should have thrown IllegalTransactionStateException"); + return Mono.empty(); + }); + }).as(StepVerifier::create) + .verifyError(IllegalTransactionStateException.class); + + verify(connectionMock).rollbackTransaction(); + verify(connectionMock).close(); + } + + @Test + public void testPropagationSupportsAndRequiresNew() { + when(connectionMock.commitTransaction()).thenReturn(Mono.empty()); + + DefaultTransactionDefinition definition = new DefaultTransactionDefinition(); + definition.setPropagationBehavior(TransactionDefinition.PROPAGATION_SUPPORTS); + TransactionalOperator operator = TransactionalOperator.create(tm, definition); + + operator.execute(tx1 -> { + + assertThat(tx1.isNewTransaction()).isFalse(); + + DefaultTransactionDefinition innerDef = new DefaultTransactionDefinition(); + innerDef.setPropagationBehavior( + TransactionDefinition.PROPAGATION_REQUIRES_NEW); + TransactionalOperator inner = TransactionalOperator.create(tm, innerDef); + + return inner.execute(tx2 -> { + + assertThat(tx2.isNewTransaction()).isTrue(); + return Mono.empty(); + }); + }).as(StepVerifier::create) + .verifyComplete(); + + verify(connectionMock).commitTransaction(); + verify(connectionMock).close(); + } + + + private static class TestTransactionSynchronization + implements TransactionSynchronization { + + private int status; + + public boolean beforeCommitCalled; + + public boolean beforeCompletionCalled; + + public boolean afterCommitCalled; + + public boolean afterCompletionCalled; + + public Throwable afterCompletionException; + + public TestTransactionSynchronization(int status) { + this.status = status; + } + + @Override + public Mono suspend() { + return Mono.empty(); + } + + @Override + public Mono resume() { + return Mono.empty(); + } + + @Override + public Mono beforeCommit(boolean readOnly) { + if (this.status != TransactionSynchronization.STATUS_COMMITTED) { + fail("Should never be called"); + } + return Mono.fromRunnable(() -> { + assertThat(this.beforeCommitCalled).isFalse(); + this.beforeCommitCalled = true; + }); + } + + @Override + public Mono beforeCompletion() { + return Mono.fromRunnable(() -> { + assertThat(this.beforeCompletionCalled).isFalse(); + this.beforeCompletionCalled = true; + }); + } + + @Override + public Mono afterCommit() { + if (this.status != TransactionSynchronization.STATUS_COMMITTED) { + fail("Should never be called"); + } + return Mono.fromRunnable(() -> { + assertThat(this.afterCommitCalled).isFalse(); + this.afterCommitCalled = true; + }); + } + + @Override + public Mono afterCompletion(int status) { + try { + return Mono.fromRunnable(() -> doAfterCompletion(status)); + } + catch (Throwable ex) { + this.afterCompletionException = ex; + } + + return Mono.empty(); + } + + protected void doAfterCompletion(int status) { + assertThat(this.afterCompletionCalled).isFalse(); + this.afterCompletionCalled = true; + assertThat(status).isEqualTo(this.status); + } + + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/SingleConnectionFactoryUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/SingleConnectionFactoryUnitTests.java new file mode 100644 index 000000000000..dd8344935d72 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/SingleConnectionFactoryUnitTests.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import io.r2dbc.h2.H2Connection; +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.r2dbc.spi.Wrapped; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.never; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link SingleConnectionFactory}. + * + * @author Mark Paluch + */ +public class SingleConnectionFactoryUnitTests { + + @Test + public void shouldAllocateSameConnection() { + SingleConnectionFactory factory = new SingleConnectionFactory( + "r2dbc:h2:mem:///foo", false); + + Mono cf1 = factory.create(); + Mono cf2 = factory.create(); + + Connection c1 = cf1.block(); + Connection c2 = cf2.block(); + + assertThat(c1).isSameAs(c2); + factory.destroy(); + } + + @Test + public void shouldApplyAutoCommit() { + SingleConnectionFactory factory = new SingleConnectionFactory( + "r2dbc:h2:mem:///foo", false); + factory.setAutoCommit(false); + + factory.create().as(StepVerifier::create) + .consumeNextWith(actual -> assertThat(actual.isAutoCommit()).isFalse()) + .verifyComplete(); + + factory.setAutoCommit(true); + + factory.create().as(StepVerifier::create) + .consumeNextWith(actual -> assertThat(actual.isAutoCommit()).isTrue()) + .verifyComplete(); + + factory.destroy(); + } + + @Test + public void shouldSuppressClose() { + SingleConnectionFactory factory = new SingleConnectionFactory( + "r2dbc:h2:mem:///foo", true); + + Connection connection = factory.create().block(); + + StepVerifier.create(connection.close()).verifyComplete(); + assertThat(connection).isInstanceOf(Wrapped.class); + assertThat(((Wrapped) connection).unwrap()).isInstanceOf(H2Connection.class); + + StepVerifier.create( + connection.setTransactionIsolationLevel(IsolationLevel.READ_COMMITTED)) + .verifyComplete(); + factory.destroy(); + } + + @Test + public void shouldNotSuppressClose() { + SingleConnectionFactory factory = new SingleConnectionFactory( + "r2dbc:h2:mem:///foo", false); + + Connection connection = factory.create().block(); + + StepVerifier.create(connection.close()).verifyComplete(); + + StepVerifier.create(connection.setTransactionIsolationLevel( + IsolationLevel.READ_COMMITTED)).verifyError( + R2dbcNonTransientResourceException.class); + factory.destroy(); + } + + @Test + public void releaseConnectionShouldNotCloseConnection() { + Connection connectionMock = mock(Connection.class); + ConnectionFactoryMetadata metadata = mock(ConnectionFactoryMetadata.class); + + SingleConnectionFactory factory = new SingleConnectionFactory( + connectionMock, metadata, true); + + Connection connection = factory.create().block(); + + ConnectionFactoryUtils.releaseConnection(connection, factory) + .as(StepVerifier::create) + .verifyComplete(); + + verify(connectionMock, never()).close(); + } + + @Test + public void releaseConnectionShouldCloseUnrelatedConnection() { + Connection connectionMock = mock(Connection.class); + Connection otherConnection = mock(Connection.class); + ConnectionFactoryMetadata metadata = mock(ConnectionFactoryMetadata.class); + when(otherConnection.close()).thenReturn(Mono.empty()); + + SingleConnectionFactory factory = new SingleConnectionFactory( + connectionMock, metadata, false); + + factory.create().as(StepVerifier::create).expectNextCount(1).verifyComplete(); + + ConnectionFactoryUtils.releaseConnection(otherConnection, factory) + .as(StepVerifier::create) + .verifyComplete(); + + verify(otherConnection).close(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/TransactionAwareConnectionFactoryProxyUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/TransactionAwareConnectionFactoryProxyUnitTests.java new file mode 100644 index 000000000000..0cc83b51db69 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/TransactionAwareConnectionFactoryProxyUnitTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection; + +import java.util.concurrent.atomic.AtomicReference; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Wrapped; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.transaction.reactive.TransactionalOperator; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.times; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoInteractions; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link TransactionAwareConnectionFactoryProxy}. + * + * @author Mark Paluch + * @author Christoph Strobl + */ +public class TransactionAwareConnectionFactoryProxyUnitTests { + + ConnectionFactory connectionFactoryMock = mock(ConnectionFactory.class); + + Connection connectionMock1 = mock(Connection.class); + + Connection connectionMock2 = mock(Connection.class); + + Connection connectionMock3 = mock(Connection.class); + + R2dbcTransactionManager tm; + + @BeforeEach + public void before() { + when(connectionFactoryMock.create()).thenReturn((Mono) Mono.just(connectionMock1), + (Mono) Mono.just(connectionMock2), (Mono) Mono.just(connectionMock3)); + tm = new R2dbcTransactionManager(connectionFactoryMock); + } + + @Test + public void createShouldWrapConnection() { + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .as(StepVerifier::create) + .consumeNextWith(connection -> assertThat(connection).isInstanceOf(Wrapped.class)) + .verifyComplete(); + } + + @Test + public void unwrapShouldReturnTargetConnection() { + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .map(Wrapped.class::cast).as(StepVerifier::create) + .consumeNextWith(wrapped -> assertThat(wrapped.unwrap()).isEqualTo(connectionMock1)) + .verifyComplete(); + } + + @Test + public void unwrapShouldReturnTargetConnectionEvenWhenClosed() { + when(connectionMock1.close()).thenReturn(Mono.empty()); + + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .map(Connection.class::cast).flatMap( + connection -> Mono.from(connection.close()).then(Mono.just(connection))).as( + StepVerifier::create) + .consumeNextWith(wrapped -> assertThat(((Wrapped) wrapped).unwrap()).isEqualTo(connectionMock1)) + .verifyComplete(); + } + + @Test + public void getTargetConnectionShouldReturnTargetConnection() { + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .map(Wrapped.class::cast).as(StepVerifier::create) + .consumeNextWith(wrapped -> assertThat(wrapped.unwrap()).isEqualTo(connectionMock1)) + .verifyComplete(); + } + + @Test + public void getMetadataShouldThrowsErrorEvenWhenClosed() { + when(connectionMock1.close()).thenReturn(Mono.empty()); + + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .map(Connection.class::cast).flatMap( + connection -> Mono.from(connection.close()) + .then(Mono.just(connection))).as(StepVerifier::create) + .consumeNextWith(connection -> assertThatIllegalStateException().isThrownBy( + connection::getMetadata)).verifyComplete(); + } + + @Test + public void hashCodeShouldReturnProxyHash() { + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .map(Connection.class::cast).as(StepVerifier::create) + .consumeNextWith(connection -> assertThat(connection.hashCode()).isEqualTo( + System.identityHashCode(connection))).verifyComplete(); + } + + @Test + public void equalsShouldCompareCorrectly() { + new TransactionAwareConnectionFactoryProxy(connectionFactoryMock).create() + .map(Connection.class::cast).as(StepVerifier::create) + .consumeNextWith(connection -> { + assertThat(connection.equals(connection)).isTrue(); + assertThat(connection.equals(connectionMock1)).isFalse(); + }).verifyComplete(); + } + + @Test + public void shouldEmitBoundConnection() { + when(connectionMock1.beginTransaction()).thenReturn(Mono.empty()); + when(connectionMock1.commitTransaction()).thenReturn(Mono.empty()); + when(connectionMock1.close()).thenReturn(Mono.empty()); + + TransactionalOperator rxtx = TransactionalOperator.create(tm); + AtomicReference transactionalConnection = new AtomicReference<>(); + + TransactionAwareConnectionFactoryProxy proxyCf = new TransactionAwareConnectionFactoryProxy( + connectionFactoryMock); + + ConnectionFactoryUtils.getConnection(connectionFactoryMock) + .doOnNext(transactionalConnection::set).flatMap(connection -> proxyCf.create() + .doOnNext(wrappedConnection -> assertThat(((Wrapped) wrappedConnection).unwrap()).isSameAs(connection))) + .as(rxtx::transactional) + .flatMapMany(Connection::close) + .as(StepVerifier::create) + .verifyComplete(); + + verifyNoInteractions(connectionMock2); + verifyNoInteractions(connectionMock3); + verify(connectionFactoryMock, times(1)).create(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/AbstractDatabaseInitializationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/AbstractDatabaseInitializationTests.java new file mode 100644 index 000000000000..dd8b13c2c2be --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/AbstractDatabaseInitializationTests.java @@ -0,0 +1,136 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import io.r2dbc.spi.ConnectionFactory; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import org.springframework.core.io.ClassRelativeResourceLoader; +import org.springframework.core.io.Resource; +import org.springframework.r2dbc.core.DatabaseClient; + + +/** + * Abstract test support for {@link DatabasePopulator}. + * + * @author Mark Paluch + */ +public abstract class AbstractDatabaseInitializationTests { + + ClassRelativeResourceLoader resourceLoader = new ClassRelativeResourceLoader( + getClass()); + + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator(); + + + @Test + public void scriptWithSingleLineCommentsAndFailedDrop() { + databasePopulator.addScript(resource("db-schema-failed-drop-comments.sql")); + databasePopulator.addScript(resource("db-test-data.sql")); + databasePopulator.setIgnoreFailedDrops(true); + + runPopulator(); + + assertUsersDatabaseCreated("Heisenberg"); + } + + private void runPopulator() { + databasePopulator.populate(getConnectionFactory()) // + .as(StepVerifier::create) // + .verifyComplete(); + } + + @Test + public void scriptWithStandardEscapedLiteral() { + databasePopulator.addScript(defaultSchema()); + databasePopulator.addScript(resource("db-test-data-escaped-literal.sql")); + + runPopulator(); + + assertUsersDatabaseCreated("'Heisenberg'"); + } + + @Test + public void scriptWithMySqlEscapedLiteral() { + databasePopulator.addScript(defaultSchema()); + databasePopulator.addScript(resource("db-test-data-mysql-escaped-literal.sql")); + + runPopulator(); + + assertUsersDatabaseCreated("\\$Heisenberg\\$"); + } + + @Test + public void scriptWithMultipleStatements() { + databasePopulator.addScript(defaultSchema()); + databasePopulator.addScript(resource("db-test-data-multiple.sql")); + + runPopulator(); + + assertUsersDatabaseCreated("Heisenberg", "Jesse"); + } + + @Test + public void scriptWithMultipleStatementsAndLongSeparator() { + databasePopulator.addScript(defaultSchema()); + databasePopulator.addScript(resource("db-test-data-endings.sql")); + databasePopulator.setSeparator("@@"); + + runPopulator(); + + assertUsersDatabaseCreated("Heisenberg", "Jesse"); + } + + abstract ConnectionFactory getConnectionFactory(); + + Resource resource(String path) { + return resourceLoader.getResource(path); + } + + Resource defaultSchema() { + return resource("db-schema.sql"); + } + + Resource usersSchema() { + return resource("users-schema.sql"); + } + + void assertUsersDatabaseCreated(String... lastNames) { + assertUsersDatabaseCreated(getConnectionFactory(), lastNames); + } + + void assertUsersDatabaseCreated(ConnectionFactory connectionFactory, + String... lastNames) { + + DatabaseClient client = DatabaseClient.create(connectionFactory); + + for (String lastName : lastNames) { + + client.sql("select count(0) from users where last_name = :name") // + .bind("name", lastName) // + .map((row, metadata) -> row.get(0)) // + .first() // + .map(number -> ((Number) number).intValue()) // + .as(StepVerifier::create) // + .expectNext(1).as( + "Did not find user with last name [" + lastName + "].") // + .verifyComplete(); + } + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/CompositeDatabasePopulatorTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/CompositeDatabasePopulatorTests.java new file mode 100644 index 000000000000..baf257a6bf7c --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/CompositeDatabasePopulatorTests.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.util.LinkedHashSet; +import java.util.Set; + +import io.r2dbc.spi.Connection; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.times; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link CompositeDatabasePopulator}. + * + * @author Mark Paluch + */ +public class CompositeDatabasePopulatorTests { + + Connection mockedConnection = mock(Connection.class); + + DatabasePopulator mockedDatabasePopulator1 = mock(DatabasePopulator.class); + + DatabasePopulator mockedDatabasePopulator2 = mock(DatabasePopulator.class); + + + @BeforeEach + public void before() { + when(mockedDatabasePopulator1.populate(mockedConnection)).thenReturn( + Mono.empty()); + when(mockedDatabasePopulator2.populate(mockedConnection)).thenReturn( + Mono.empty()); + } + + @Test + public void addPopulators() { + CompositeDatabasePopulator populator = new CompositeDatabasePopulator(); + populator.addPopulators(mockedDatabasePopulator1, mockedDatabasePopulator2); + + populator.populate(mockedConnection).as(StepVerifier::create).verifyComplete(); + + verify(mockedDatabasePopulator1, times(1)).populate(mockedConnection); + verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection); + } + + @Test + public void setPopulatorsWithMultiple() { + CompositeDatabasePopulator populator = new CompositeDatabasePopulator(); + populator.setPopulators(mockedDatabasePopulator1, mockedDatabasePopulator2); // multiple + + populator.populate(mockedConnection).as(StepVerifier::create).verifyComplete(); + + verify(mockedDatabasePopulator1, times(1)).populate(mockedConnection); + verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection); + } + + @Test + public void setPopulatorsForOverride() { + CompositeDatabasePopulator populator = new CompositeDatabasePopulator(); + populator.setPopulators(mockedDatabasePopulator1); + populator.setPopulators(mockedDatabasePopulator2); // override + + populator.populate(mockedConnection).as(StepVerifier::create).verifyComplete(); + + verify(mockedDatabasePopulator1, times(0)).populate(mockedConnection); + verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection); + } + + @Test + public void constructWithVarargs() { + CompositeDatabasePopulator populator = new CompositeDatabasePopulator( + mockedDatabasePopulator1, mockedDatabasePopulator2); + + populator.populate(mockedConnection).as(StepVerifier::create).verifyComplete(); + + verify(mockedDatabasePopulator1, times(1)).populate(mockedConnection); + verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection); + } + + @Test + public void constructWithCollection() { + Set populators = new LinkedHashSet<>(); + populators.add(mockedDatabasePopulator1); + populators.add(mockedDatabasePopulator2); + + CompositeDatabasePopulator populator = new CompositeDatabasePopulator(populators); + populator.populate(mockedConnection).as(StepVerifier::create).verifyComplete(); + + verify(mockedDatabasePopulator1, times(1)).populate(mockedConnection); + verify(mockedDatabasePopulator2, times(1)).populate(mockedConnection); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ConnectionFactoryInitializerUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ConnectionFactoryInitializerUnitTests.java new file mode 100644 index 000000000000..eadea3a14d1c --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ConnectionFactoryInitializerUnitTests.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.util.concurrent.atomic.AtomicBoolean; + +import io.r2dbc.spi.test.MockConnection; +import io.r2dbc.spi.test.MockConnectionFactory; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link ConnectionFactoryInitializer}. + * + * @author Mark Paluch + */ +public class ConnectionFactoryInitializerUnitTests { + + AtomicBoolean called = new AtomicBoolean(); + + DatabasePopulator populator = mock(DatabasePopulator.class); + + MockConnection connection = MockConnection.builder().build(); + + MockConnectionFactory connectionFactory = MockConnectionFactory.builder().connection( + connection).build(); + + + @Test + public void shouldInitializeConnectionFactory() { + when(populator.populate(connectionFactory)).thenReturn( + Mono. empty().doOnSubscribe(subscription -> called.set(true))); + + ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer(); + initializer.setConnectionFactory(connectionFactory); + initializer.setDatabasePopulator(populator); + + initializer.afterPropertiesSet(); + + assertThat(called).isTrue(); + } + + @Test + public void shouldCleanConnectionFactory() { + when(populator.populate(connectionFactory)).thenReturn( + Mono. empty().doOnSubscribe(subscription -> called.set(true))); + + ConnectionFactoryInitializer initializer = new ConnectionFactoryInitializer(); + initializer.setConnectionFactory(connectionFactory); + initializer.setDatabaseCleaner(populator); + + initializer.afterPropertiesSet(); + initializer.destroy(); + + assertThat(called).isTrue(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/H2DatabasePopulatorIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/H2DatabasePopulatorIntegrationTests.java new file mode 100644 index 000000000000..643beb6301fd --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/H2DatabasePopulatorIntegrationTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.util.UUID; + +import io.r2dbc.spi.ConnectionFactories; +import io.r2dbc.spi.ConnectionFactory; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +/** + * Integration tests for {@link DatabasePopulator} using H2. + * + * @author Mark Paluch + */ +public class H2DatabasePopulatorIntegrationTests + extends AbstractDatabaseInitializationTests { + + UUID databaseName = UUID.randomUUID(); + + ConnectionFactory connectionFactory = ConnectionFactories.get("r2dbc:h2:mem:///" + + databaseName + "?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE"); + + + @Override + ConnectionFactory getConnectionFactory() { + return this.connectionFactory; + } + + @Test + public void shouldRunScript() { + + databasePopulator.addScript(usersSchema()); + databasePopulator.addScript(resource("db-test-data-h2.sql")); + // Set statement separator to double newline so that ";" is not + // considered a statement separator within the source code of the + // aliased function 'REVERSE'. + databasePopulator.setSeparator("\n\n"); + + databasePopulator.populate(connectionFactory).as( + StepVerifier::create).verifyComplete(); + + assertUsersDatabaseCreated(connectionFactory, "White"); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ResourceDatabasePopulatorUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ResourceDatabasePopulatorUnitTests.java new file mode 100644 index 000000000000..af84b9797071 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ResourceDatabasePopulatorUnitTests.java @@ -0,0 +1,117 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import org.junit.jupiter.api.Test; + +import org.springframework.core.io.Resource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.BDDMockito.mock; + +/** + * Unit tests for {@link ResourceDatabasePopulator}. + * + * @author Mark Paluch + */ +public class ResourceDatabasePopulatorUnitTests { + + private static final Resource script1 = mock(Resource.class); + + private static final Resource script2 = mock(Resource.class); + + private static final Resource script3 = mock(Resource.class); + + + @Test + public void constructWithNullResource() { + assertThatIllegalArgumentException().isThrownBy( + () -> new ResourceDatabasePopulator((Resource) null)); + } + + @Test + public void constructWithNullResourceArray() { + assertThatIllegalArgumentException().isThrownBy( + () -> new ResourceDatabasePopulator((Resource[]) null)); + } + + @Test + public void constructWithResource() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator( + script1); + assertThat(databasePopulator.scripts).hasSize(1); + } + + @Test + public void constructWithMultipleResources() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator( + script1, script2); + assertThat(databasePopulator.scripts).hasSize(2); + } + + @Test + public void constructWithMultipleResourcesAndThenAddScript() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator( + script1, script2); + assertThat(databasePopulator.scripts).hasSize(2); + + databasePopulator.addScript(script3); + assertThat(databasePopulator.scripts).hasSize(3); + } + + @Test + public void addScriptsWithNullResource() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator(); + assertThatIllegalArgumentException().isThrownBy( + () -> databasePopulator.addScripts((Resource) null)); + } + + @Test + public void addScriptsWithNullResourceArray() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator(); + assertThatIllegalArgumentException().isThrownBy( + () -> databasePopulator.addScripts((Resource[]) null)); + } + + @Test + public void setScriptsWithNullResource() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator(); + assertThatIllegalArgumentException().isThrownBy( + () -> databasePopulator.setScripts((Resource) null)); + } + + @Test + public void setScriptsWithNullResourceArray() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator(); + assertThatIllegalArgumentException().isThrownBy( + () -> databasePopulator.setScripts((Resource[]) null)); + } + + @Test + public void setScriptsAndThenAddScript() { + ResourceDatabasePopulator databasePopulator = new ResourceDatabasePopulator(); + assertThat(databasePopulator.scripts).isEmpty(); + + databasePopulator.setScripts(script1, script2); + assertThat(databasePopulator.scripts).hasSize(2); + + databasePopulator.addScript(script3); + assertThat(databasePopulator.scripts).hasSize(3); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java new file mode 100644 index 000000000000..b72814de704b --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/init/ScriptUtilsUnitTests.java @@ -0,0 +1,219 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.init; + +import java.util.ArrayList; +import java.util.List; + +import org.assertj.core.util.Strings; +import org.junit.jupiter.api.Test; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.support.EncodedResource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ScriptUtils}. + * + * @author Thomas Risberg + * @author Sam Brannen + * @author Phillip Webb + * @author Chris Baldwin + * @author Nicolas Debeissat + * @author Mark Paluch + */ +public class ScriptUtilsUnitTests { + + @Test + public void splitSqlScriptDelimitedWithSemicolon() { + String rawStatement1 = "insert into customer (id, name)\nvalues (1, 'Rod ; Johnson'), (2, 'Adrian \n Collier')"; + String cleanedStatement1 = "insert into customer (id, name) values (1, 'Rod ; Johnson'), (2, 'Adrian \n Collier')"; + String rawStatement2 = "insert into orders(id, order_date, customer_id)\nvalues (1, '2008-01-02', 2)"; + String cleanedStatement2 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + String rawStatement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + String cleanedStatement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + + String script = Strings.join(rawStatement1, rawStatement2, rawStatement3).with( + ";"); + + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, ";", statements); + + assertThat(statements).hasSize(3).containsSequence(cleanedStatement1, + cleanedStatement2, cleanedStatement3); + } + + @Test + public void splitSqlScriptDelimitedWithNewLine() { + String statement1 = "insert into customer (id, name) values (1, 'Rod ; Johnson'), (2, 'Adrian \n Collier')"; + String statement2 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + String statement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + + String script = Strings.join(statement1, statement2, statement3).with("\n"); + + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, "\n", statements); + + assertThat(statements).hasSize(3).containsSequence(statement1, statement2, + statement3); + } + + @Test + public void splitSqlScriptDelimitedWithNewLineButDefaultDelimiterSpecified() { + String statement1 = "do something"; + String statement2 = "do something else"; + + char delim = '\n'; + String script = statement1 + delim + statement2 + delim; + + List statements = new ArrayList<>(); + + ScriptUtils.splitSqlScript(script, ScriptUtils.DEFAULT_STATEMENT_SEPARATOR, + statements); + + assertThat(statements).hasSize(1).contains(script.replace('\n', ' ')); + } + + @Test + public void splitScriptWithSingleQuotesNestedInsideDoubleQuotes() { + String statement1 = "select '1' as \"Dogbert's owner's\" from dual"; + String statement2 = "select '2' as \"Dilbert's\" from dual"; + + char delim = ';'; + String script = statement1 + delim + statement2 + delim; + + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, ';', statements); + + assertThat(statements).hasSize(2).containsSequence(statement1, statement2); + } + + @Test + public void readAndSplitScriptWithMultipleNewlinesAsSeparator() { + String script = readScript("db-test-data-multi-newline.sql"); + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, "\n\n", statements); + + String statement1 = "insert into users (last_name) values ('Walter')"; + String statement2 = "insert into users (last_name) values ('Jesse')"; + + assertThat(statements.size()).as("wrong number of statements").isEqualTo(2); + assertThat(statements.get(0)).as("statement 1 not split correctly").isEqualTo( + statement1); + assertThat(statements.get(1)).as("statement 2 not split correctly").isEqualTo( + statement2); + } + + @Test + public void readAndSplitScriptContainingComments() { + String script = readScript("test-data-with-comments.sql"); + splitScriptContainingComments(script); + } + + @Test + public void readAndSplitScriptContainingCommentsWithWindowsLineEnding() { + String script = readScript("test-data-with-comments.sql").replaceAll("\n", + "\r\n"); + splitScriptContainingComments(script); + } + + private void splitScriptContainingComments(String script) { + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, ';', statements); + + String statement1 = "insert into customer (id, name) values (1, 'Rod; Johnson'), (2, 'Adrian Collier')"; + String statement2 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + String statement3 = "insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2)"; + String statement4 = "INSERT INTO persons( person_id , name) VALUES( 1 , 'Name' )"; + + assertThat(statements).hasSize(4).containsSequence(statement1, statement2, + statement3, statement4); + } + + @Test + public void readAndSplitScriptContainingCommentsWithLeadingTabs() { + String script = readScript("test-data-with-comments-and-leading-tabs.sql"); + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, ';', statements); + + String statement1 = "insert into customer (id, name) values (1, 'Walter White')"; + String statement2 = "insert into orders(id, order_date, customer_id) values (1, '2013-06-08', 1)"; + String statement3 = "insert into orders(id, order_date, customer_id) values (2, '2013-06-08', 1)"; + + assertThat(statements).hasSize(3).containsSequence(statement1, statement2, + statement3); + } + + @Test + public void readAndSplitScriptContainingMultiLineComments() { + String script = readScript("test-data-with-multi-line-comments.sql"); + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, ';', statements); + + String statement1 = "INSERT INTO users(first_name, last_name) VALUES('Walter', 'White')"; + String statement2 = "INSERT INTO users(first_name, last_name) VALUES( 'Jesse' , 'Pinkman' )"; + + assertThat(statements).hasSize(2).containsSequence(statement1, statement2); + } + + @Test + public void readAndSplitScriptContainingMultiLineNestedComments() { + String script = readScript("test-data-with-multi-line-nested-comments.sql"); + List statements = new ArrayList<>(); + ScriptUtils.splitSqlScript(script, ';', statements); + + String statement1 = "INSERT INTO users(first_name, last_name) VALUES('Walter', 'White')"; + String statement2 = "INSERT INTO users(first_name, last_name) VALUES( 'Jesse' , 'Pinkman' )"; + + assertThat(statements).hasSize(2).containsSequence(statement1, statement2); + } + + @Test + public void containsDelimiters() { + assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select ';'", + ";")).isFalse(); + assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1; select 2", + ";")).isTrue(); + + assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1; select '\\n\n';", + "\n")).isFalse(); + assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select 2", + "\n")).isTrue(); + + assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n select 2", + "\n\n")).isFalse(); + assertThat(ScriptUtils.containsSqlScriptDelimiters("select 1\n\n select 2", + "\n\n")).isTrue(); + + // MySQL style escapes '\\' + assertThat(ScriptUtils.containsSqlScriptDelimiters( + "insert into users(first_name, last_name)\nvalues('a\\\\', 'b;')", + ";")).isFalse(); + assertThat(ScriptUtils.containsSqlScriptDelimiters( + "insert into users(first_name, last_name)\nvalues('Charles', 'd\\'Artagnan'); select 1;", + ";")).isTrue(); + } + + private String readScript(String path) { + EncodedResource resource = new EncodedResource( + new ClassPathResource(path, getClass())); + return ScriptUtils.readScript(resource, new DefaultDataBufferFactory()).block(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/AbstractRoutingConnectionFactoryUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/AbstractRoutingConnectionFactoryUnitTests.java new file mode 100644 index 000000000000..233e22dbc5be --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/AbstractRoutingConnectionFactoryUnitTests.java @@ -0,0 +1,195 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import io.r2dbc.spi.ConnectionFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.context.Context; + +import static java.util.Collections.singletonMap; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link AbstractRoutingConnectionFactory}. + * + * @author Mark Paluch + * @author Jens Schauder + */ +@ExtendWith(MockitoExtension.class) +public class AbstractRoutingConnectionFactoryUnitTests { + + private static final String ROUTING_KEY = "routingKey"; + + @Mock + ConnectionFactory defaultConnectionFactory; + + @Mock + ConnectionFactory routedConnectionFactory; + + DummyRoutingConnectionFactory connectionFactory; + + @BeforeEach + public void before() { + connectionFactory = new DummyRoutingConnectionFactory(); + connectionFactory.setDefaultTargetConnectionFactory(defaultConnectionFactory); + } + + @Test + public void shouldDetermineRoutedFactory() { + + connectionFactory.setTargetConnectionFactories( + singletonMap("key", routedConnectionFactory)); + connectionFactory.setConnectionFactoryLookup(new MapConnectionFactoryLookup()); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .subscriberContext(Context.of(ROUTING_KEY, "key")) + .as(StepVerifier::create) + .expectNext(routedConnectionFactory) + .verifyComplete(); + } + + @Test + public void shouldFallbackToDefaultConnectionFactory() { + connectionFactory.setTargetConnectionFactories( + singletonMap("key", routedConnectionFactory)); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .as(StepVerifier::create) + .expectNext(defaultConnectionFactory) + .verifyComplete(); + } + + @Test + public void initializationShouldFailUnsupportedLookupKey() { + connectionFactory.setTargetConnectionFactories(singletonMap("key", new Object())); + + assertThatThrownBy(() -> connectionFactory.afterPropertiesSet()).isInstanceOf( + IllegalArgumentException.class); + } + + @Test + public void initializationShouldFailUnresolvableKey() { + connectionFactory.setTargetConnectionFactories(singletonMap("key", "value")); + connectionFactory.setConnectionFactoryLookup(new MapConnectionFactoryLookup()); + + assertThatThrownBy(() -> connectionFactory.afterPropertiesSet()) + .isInstanceOf(ConnectionFactoryLookupFailureException.class) + .hasMessageContaining( + "No ConnectionFactory with name 'value' registered"); + } + + @Test + public void unresolvableConnectionFactoryRetrievalShouldFail() { + connectionFactory.setLenientFallback(false); + connectionFactory.setConnectionFactoryLookup(new MapConnectionFactoryLookup()); + connectionFactory.setTargetConnectionFactories( + singletonMap("key", routedConnectionFactory)); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .subscriberContext(Context.of(ROUTING_KEY, "unknown")) + .as(StepVerifier::create) + .verifyError(IllegalStateException.class); + } + + @Test + public void connectionFactoryRetrievalWithUnknownLookupKeyShouldReturnDefaultConnectionFactory() { + connectionFactory.setTargetConnectionFactories( + singletonMap("key", routedConnectionFactory)); + connectionFactory.setDefaultTargetConnectionFactory(defaultConnectionFactory); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .subscriberContext(Context.of(ROUTING_KEY, "unknown")) + .as(StepVerifier::create) + .expectNext(defaultConnectionFactory) + .verifyComplete(); + } + + @Test + public void connectionFactoryRetrievalWithoutLookupKeyShouldReturnDefaultConnectionFactory() { + connectionFactory.setTargetConnectionFactories( + singletonMap("key", routedConnectionFactory)); + connectionFactory.setDefaultTargetConnectionFactory(defaultConnectionFactory); + connectionFactory.setLenientFallback(false); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .as(StepVerifier::create) + .expectNext(defaultConnectionFactory) + .verifyComplete(); + } + + @Test + public void shouldLookupFromMap() { + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup("lookup-key", + routedConnectionFactory); + + connectionFactory.setConnectionFactoryLookup(lookup); + connectionFactory.setTargetConnectionFactories( + singletonMap("my-key", "lookup-key")); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .subscriberContext(Context.of(ROUTING_KEY, "my-key")) + .as(StepVerifier::create) + .expectNext(routedConnectionFactory) + .verifyComplete(); + } + + @Test + public void shouldAllowModificationsAfterInitialization() { + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup(); + + connectionFactory.setConnectionFactoryLookup(lookup); + connectionFactory.setTargetConnectionFactories(lookup.getConnectionFactories()); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .subscriberContext(Context.of(ROUTING_KEY, "lookup-key")) + .as(StepVerifier::create) + .expectNext(defaultConnectionFactory) + .verifyComplete(); + + lookup.addConnectionFactory("lookup-key", routedConnectionFactory); + connectionFactory.afterPropertiesSet(); + + connectionFactory.determineTargetConnectionFactory() + .subscriberContext(Context.of(ROUTING_KEY, "lookup-key")) + .as(StepVerifier::create) + .expectNext(routedConnectionFactory) + .verifyComplete(); + } + + static class DummyRoutingConnectionFactory extends AbstractRoutingConnectionFactory { + + @Override + protected Mono determineCurrentLookupKey() { + return Mono.subscriberContext().filter(context -> context.hasKey(ROUTING_KEY)) + .map(context -> context.get(ROUTING_KEY)); + } + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/BeanFactoryConnectionFactoryLookupUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/BeanFactoryConnectionFactoryLookupUnitTests.java new file mode 100644 index 000000000000..0e2bf0a3538a --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/BeanFactoryConnectionFactoryLookupUnitTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import io.r2dbc.spi.ConnectionFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanNotOfRequiredTypeException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link BeanFactoryConnectionFactoryLookup}. + * + * @author Mark Paluch + */ +@ExtendWith(MockitoExtension.class) +public class BeanFactoryConnectionFactoryLookupUnitTests { + + private static final String CONNECTION_FACTORY_BEAN_NAME = "connectionFactory"; + + @Mock + BeanFactory beanFactory; + + @Test + public void shouldLookupConnectionFactory() { + DummyConnectionFactory expectedConnectionFactory = new DummyConnectionFactory(); + when(beanFactory.getBean(CONNECTION_FACTORY_BEAN_NAME, + ConnectionFactory.class)).thenReturn(expectedConnectionFactory); + + BeanFactoryConnectionFactoryLookup lookup = new BeanFactoryConnectionFactoryLookup(); + lookup.setBeanFactory(beanFactory); + + ConnectionFactory connectionFactory = lookup.getConnectionFactory( + CONNECTION_FACTORY_BEAN_NAME); + + assertThat(connectionFactory).isNotNull(); + assertThat(connectionFactory).isSameAs(expectedConnectionFactory); + } + + @Test + public void shouldLookupWhereBeanFactoryYieldsNonConnectionFactoryType() { + BeanFactory beanFactory = mock(BeanFactory.class); + + when(beanFactory.getBean(CONNECTION_FACTORY_BEAN_NAME, + ConnectionFactory.class)).thenThrow( + new BeanNotOfRequiredTypeException(CONNECTION_FACTORY_BEAN_NAME, + ConnectionFactory.class, String.class)); + + BeanFactoryConnectionFactoryLookup lookup = new BeanFactoryConnectionFactoryLookup( + beanFactory); + + assertThatExceptionOfType( + ConnectionFactoryLookupFailureException.class).isThrownBy( + () -> lookup.getConnectionFactory(CONNECTION_FACTORY_BEAN_NAME)); + } + + @Test + public void shouldLookupWhereBeanFactoryHasNotBeenSupplied() { + BeanFactoryConnectionFactoryLookup lookup = new BeanFactoryConnectionFactoryLookup(); + + assertThatThrownBy(() -> lookup.getConnectionFactory( + CONNECTION_FACTORY_BEAN_NAME)).isInstanceOf(IllegalStateException.class); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/DummyConnectionFactory.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/DummyConnectionFactory.java new file mode 100644 index 000000000000..1d3b8a849b1b --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/DummyConnectionFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import org.reactivestreams.Publisher; + +/** + * Stub, do-nothing {@link ConnectionFactory} implementation. + *

+ * All methods throw {@link UnsupportedOperationException}. + * + * @author Mark Paluch + */ +class DummyConnectionFactory implements ConnectionFactory { + + @Override + public Publisher create() { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + throw new UnsupportedOperationException(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/MapConnectionFactoryLookupUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/MapConnectionFactoryLookupUnitTests.java new file mode 100644 index 000000000000..a61abdbdfaff --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/connection/lookup/MapConnectionFactoryLookupUnitTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.connection.lookup; + +import java.util.HashMap; +import java.util.Map; + +import io.r2dbc.spi.ConnectionFactory; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MapConnectionFactoryLookup}. + * + * @author Mark Paluch + */ +public class MapConnectionFactoryLookupUnitTests { + + private static final String CONNECTION_FACTORY_NAME = "connectionFactory"; + + @Test + public void getConnectionFactorysReturnsUnmodifiableMap() { + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup(); + Map connectionFactories = lookup.getConnectionFactories(); + + assertThatThrownBy(() -> connectionFactories.put("", + new DummyConnectionFactory())).isInstanceOf( + UnsupportedOperationException.class); + } + + @Test + public void shouldLookupConnectionFactory() { + Map connectionFactories = new HashMap<>(); + DummyConnectionFactory expectedConnectionFactory = new DummyConnectionFactory(); + + connectionFactories.put(CONNECTION_FACTORY_NAME, expectedConnectionFactory); + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup(); + + lookup.setConnectionFactories(connectionFactories); + + ConnectionFactory connectionFactory = lookup.getConnectionFactory( + CONNECTION_FACTORY_NAME); + + assertThat(connectionFactory).isNotNull().isSameAs(expectedConnectionFactory); + } + + @Test + public void addingConnectionFactoryPermitsOverride() { + Map connectionFactories = new HashMap<>(); + DummyConnectionFactory overriddenConnectionFactory = new DummyConnectionFactory(); + DummyConnectionFactory expectedConnectionFactory = new DummyConnectionFactory(); + connectionFactories.put(CONNECTION_FACTORY_NAME, overriddenConnectionFactory); + + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup(); + + lookup.setConnectionFactories(connectionFactories); + lookup.addConnectionFactory(CONNECTION_FACTORY_NAME, expectedConnectionFactory); + + ConnectionFactory connectionFactory = lookup.getConnectionFactory( + CONNECTION_FACTORY_NAME); + + assertThat(connectionFactory).isNotNull().isSameAs(expectedConnectionFactory); + } + + @Test + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void getConnectionFactoryWhereSuppliedMapHasNonConnectionFactoryTypeUnderSpecifiedKey() { + Map connectionFactories = new HashMap<>(); + connectionFactories.put(CONNECTION_FACTORY_NAME, new Object()); + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup( + connectionFactories); + + assertThatThrownBy( + () -> lookup.getConnectionFactory(CONNECTION_FACTORY_NAME)).isInstanceOf( + ClassCastException.class); + } + + @Test + public void getConnectionFactoryWhereSuppliedMapHasNoEntryForSpecifiedKey() { + MapConnectionFactoryLookup lookup = new MapConnectionFactoryLookup(); + + assertThatThrownBy( + () -> lookup.getConnectionFactory(CONNECTION_FACTORY_NAME)).isInstanceOf( + ConnectionFactoryLookupFailureException.class); + } +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java new file mode 100644 index 000000000000..dff746cd59e6 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractDatabaseClientIntegrationTests.java @@ -0,0 +1,152 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Result; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.dao.DataIntegrityViolationException; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link DatabaseClient}. + * + * @author Mark Paluch + * @author Mingyuan Wu + */ +public abstract class AbstractDatabaseClientIntegrationTests { + + private ConnectionFactory connectionFactory; + + @BeforeEach + public void before() { + connectionFactory = createConnectionFactory(); + + Mono.from(connectionFactory.create()) + .flatMapMany(connection -> Flux.from(connection.createStatement("DROP TABLE legoset").execute()) + .flatMap(Result::getRowsUpdated) + .onErrorResume(e -> Mono.empty()) + .thenMany(connection.createStatement(getCreateTableStatement()).execute()) + .flatMap(Result::getRowsUpdated).thenMany(connection.close())).as(StepVerifier::create) + .verifyComplete(); + } + + /** + * Creates a {@link ConnectionFactory} to be used in this test. + * + * @return the {@link ConnectionFactory} to be used in this test + */ + protected abstract ConnectionFactory createConnectionFactory(); + + /** + * Return the the CREATE TABLE statement for table {@code legoset} with the following + * three columns: + *

    + *
  • id integer (primary key), not null
  • + *
  • name varchar(255), nullable
  • + *
  • manual integer, nullable
  • + *
+ * + * @return the CREATE TABLE statement for table {@code legoset} with three columns. + */ + protected abstract String getCreateTableStatement(); + + @Test + public void executeInsert() { + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + databaseClient.sql("INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)") + .bind("id", 42055) + .bind("name", "SCHAUFELRADBAGGER") + .bindNull("manual", Integer.class) + .fetch().rowsUpdated() + .as(StepVerifier::create) + .expectNext(1) + .verifyComplete(); + + databaseClient.sql("SELECT id FROM legoset") + .map(row -> row.get("id")) + .first() + .as(StepVerifier::create) + .assertNext(actual -> { + assertThat(actual).isInstanceOf(Number.class); + assertThat(((Number) actual).intValue()).isEqualTo(42055); + }).verifyComplete(); + } + + @Test + public void shouldTranslateDuplicateKeyException() { + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + executeInsert(); + + databaseClient.sql( + "INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)") + .bind("id", 42055) + .bind("name", "SCHAUFELRADBAGGER") + .bindNull("manual", Integer.class) + .fetch().rowsUpdated() + .as(StepVerifier::create) + .expectErrorSatisfies(exception -> assertThat(exception) + .isInstanceOf(DataIntegrityViolationException.class) + .hasMessageContaining("execute; SQL [INSERT INTO legoset")) + .verify(); + } + + @Test + public void executeDeferred() { + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + databaseClient.sql(() -> "INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)") + .bind("id", 42055) + .bind("name", "SCHAUFELRADBAGGER") + .bindNull("manual", Integer.class) + .fetch().rowsUpdated() + .as(StepVerifier::create) + .expectNext(1) + .verifyComplete(); + + databaseClient.sql("SELECT id FROM legoset") + .map(row -> row.get("id")).first() + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @Test + public void shouldEmitGeneratedKey() { + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + databaseClient.sql( + "INSERT INTO legoset ( name, manual) VALUES(:name, :manual)") + .bind("name","SCHAUFELRADBAGGER") + .bindNull("manual", Integer.class) + .filter(statement -> statement.returnGeneratedValues("id")) + .map(row -> (Number) row.get("id")) + .first() + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractTransactionalDatabaseClientIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractTransactionalDatabaseClientIntegrationTests.java new file mode 100644 index 000000000000..1d90bc7a45cd --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/AbstractTransactionalDatabaseClientIntegrationTests.java @@ -0,0 +1,208 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Result; +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.r2dbc.connection.R2dbcTransactionManager; +import org.springframework.transaction.ReactiveTransactionManager; +import org.springframework.transaction.reactive.TransactionalOperator; +import org.springframework.transaction.support.DefaultTransactionDefinition; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Abstract base class for transactional integration tests for {@link DatabaseClient}. + * + * @author Mark Paluch + * @author Christoph Strobl + */ +public abstract class AbstractTransactionalDatabaseClientIntegrationTests { + + private ConnectionFactory connectionFactory; + + AnnotationConfigApplicationContext context; + + DatabaseClient databaseClient; + R2dbcTransactionManager transactionManager; + TransactionalOperator rxtx; + + @BeforeEach + public void before() { + + connectionFactory = createConnectionFactory(); + + context = new AnnotationConfigApplicationContext(); + context.getBeanFactory().registerResolvableDependency(ConnectionFactory.class, connectionFactory); + context.register(Config.class); + context.refresh(); + + + Mono.from(connectionFactory.create()) + .flatMapMany(connection -> Flux.from(connection.createStatement("DROP TABLE legoset").execute()) + .flatMap(Result::getRowsUpdated) + .onErrorResume(e -> Mono.empty()) + .thenMany(connection.createStatement(getCreateTableStatement()).execute()) + .flatMap(Result::getRowsUpdated).thenMany(connection.close())).as(StepVerifier::create).verifyComplete(); + + databaseClient = DatabaseClient.create(connectionFactory); + transactionManager = new R2dbcTransactionManager(connectionFactory); + rxtx = TransactionalOperator.create(transactionManager); + } + + @AfterEach + public void tearDown() { + context.close(); + } + + /** + * Create a {@link ConnectionFactory} to be used in this test. + * @return the {@link ConnectionFactory} to be used in this test. + */ + protected abstract ConnectionFactory createConnectionFactory(); + + /** + * Return the the CREATE TABLE statement for table {@code legoset} with the following three columns: + *
    + *
  • id integer (primary key), not null
  • + *
  • name varchar(255), nullable
  • + *
  • manual integer, nullable
  • + *
+ * + * @return the CREATE TABLE statement for table {@code legoset} with three columns. + */ + protected abstract String getCreateTableStatement(); + + /** + * Get a parameterized {@code INSERT INTO legoset} statement setting id, name, and manual values. + */ + protected String getInsertIntoLegosetStatement() { + return "INSERT INTO legoset (id, name, manual) VALUES(:id, :name, :manual)"; + } + + @Test + public void executeInsertInTransaction() { + Flux integerFlux = databaseClient + .sql(getInsertIntoLegosetStatement()) + .bind(0, 42055) + .bind(1, "SCHAUFELRADBAGGER") + .bindNull(2, Integer.class) + .fetch().rowsUpdated().flux().as(rxtx::transactional); + + integerFlux.as(StepVerifier::create) + .expectNext(1) + .verifyComplete(); + + databaseClient + .sql("SELECT id FROM legoset") + .fetch() + .first() + .as(StepVerifier::create) + .assertNext(actual -> assertThat(actual).hasEntrySatisfying("id", numberOf(42055))) + .verifyComplete(); + } + + @Test + public void shouldRollbackTransaction() { + + Mono integerFlux = databaseClient.sql(getInsertIntoLegosetStatement()) + .bind(0, 42055) + .bind(1, "SCHAUFELRADBAGGER") + .bindNull(2, Integer.class) + .fetch().rowsUpdated() + .then(Mono.error(new IllegalStateException("failed"))) + .as(rxtx::transactional); + + integerFlux.as(StepVerifier::create) + .expectError(IllegalStateException.class) + .verify(); + + databaseClient + .sql("SELECT id FROM legoset") + .fetch() + .first() + .as(StepVerifier::create) + .verifyComplete(); + } + + @Test + public void shouldRollbackTransactionUsingTransactionalOperator() { + + DatabaseClient databaseClient = DatabaseClient.create(connectionFactory); + + TransactionalOperator transactionalOperator = TransactionalOperator + .create(new R2dbcTransactionManager(connectionFactory), new DefaultTransactionDefinition()); + + Flux integerFlux = databaseClient.sql(getInsertIntoLegosetStatement()) + .bind(0, 42055) + .bind(1, "SCHAUFELRADBAGGER") + .bindNull(2, Integer.class) + .fetch().rowsUpdated() + .thenMany(Mono.fromSupplier(() -> { + throw new IllegalStateException("failed"); + })); + + integerFlux.as(transactionalOperator::transactional) + .as(StepVerifier::create) + .expectError(IllegalStateException.class) + .verify(); + + databaseClient + .sql("SELECT id FROM legoset") + .fetch() + .first() + .as(StepVerifier::create) + .verifyComplete(); + } + + private Condition numberOf(int expected) { + return new Condition<>(object -> object instanceof Number && + ((Number) object).intValue() == expected, "Number %d", expected); + } + + + @Configuration(proxyBeanMethods = false) + static class Config { + + @Autowired GenericApplicationContext context; + + @Bean + ReactiveTransactionManager txMgr(ConnectionFactory connectionFactory) { + return new R2dbcTransactionManager(connectionFactory); + } + + @Bean + TransactionalOperator transactionalOperator(ReactiveTransactionManager transactionManager) { + return TransactionalOperator.create(transactionManager); + } + + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java new file mode 100644 index 000000000000..61f593b377fa --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/DefaultDatabaseClientUnitTests.java @@ -0,0 +1,435 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.Arrays; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; +import io.r2dbc.spi.test.MockColumnMetadata; +import io.r2dbc.spi.test.MockResult; +import io.r2dbc.spi.test.MockRow; +import io.r2dbc.spi.test.MockRowMetadata; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.lang.Nullable; +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.r2dbc.core.binding.BindTarget; + +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.doReturn; +import static org.mockito.BDDMockito.inOrder; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.times; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoInteractions; +import static org.mockito.BDDMockito.verifyNoMoreInteractions; +import static org.mockito.BDDMockito.when; + +/** + * Unit tests for {@link DefaultDatabaseClient}. + * + * @author Mark Paluch + * @author Ferdinand Jacobs + * @author Jens Schauder + */ +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) +public class DefaultDatabaseClientUnitTests { + + @Mock + Connection connection; + + private DatabaseClient.Builder databaseClientBuilder; + + @BeforeEach + public void before() { + ConnectionFactory connectionFactory = mock(ConnectionFactory.class); + + when(connectionFactory.create()).thenReturn((Publisher) Mono.just(connection)); + when(connection.close()).thenReturn(Mono.empty()); + + databaseClientBuilder = DatabaseClient.builder().connectionFactory( + connectionFactory).bindMarkers(BindMarkersFactory.indexed("$", 1)); + } + + @Test + public void shouldCloseConnectionOnlyOnce() { + DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) databaseClientBuilder.build(); + + Flux flux = databaseClient.inConnectionMany(connection -> Flux.empty()); + + flux.subscribe(new CoreSubscriber() { + + Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + s.request(1); + subscription = s; + } + + @Override + public void onNext(Object o) { + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onComplete() { + subscription.cancel(); + } + }); + + verify(connection, times(1)).close(); + } + + @Test + public void executeShouldBindNullValues() { + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + + DatabaseClient databaseClient = databaseClientBuilder.namedParameters(false).build(); + + databaseClient.sql("SELECT * FROM table WHERE key = $1").bindNull(0, + String.class).then().as(StepVerifier::create).verifyComplete(); + + verify(statement).bindNull(0, String.class); + + databaseClient.sql("SELECT * FROM table WHERE key = $1").bindNull("$1", + String.class).then().as(StepVerifier::create).verifyComplete(); + + verify(statement).bindNull("$1", String.class); + } + + @Test + public void executeShouldBindSettableValues() { + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + + DatabaseClient databaseClient = databaseClientBuilder.namedParameters(false).build(); + + databaseClient.sql("SELECT * FROM table WHERE key = $1").bind(0, + Parameter.empty(String.class)).then().as( + StepVerifier::create).verifyComplete(); + + verify(statement).bindNull(0, String.class); + + databaseClient.sql("SELECT * FROM table WHERE key = $1").bind("$1", + Parameter.empty(String.class)).then().as( + StepVerifier::create).verifyComplete(); + + verify(statement).bindNull("$1", String.class); + } + + @Test + public void executeShouldBindNamedNullValues() { + + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT * FROM table WHERE key = :key").bindNull("key", + String.class).then().as(StepVerifier::create).verifyComplete(); + + verify(statement).bindNull(0, String.class); + } + + @Test + public void executeShouldBindNamedValuesFromIndexes() { + Statement statement = mockStatementFor( + "SELECT id, name, manual FROM legoset WHERE name IN ($1, $2, $3)"); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql( + "SELECT id, name, manual FROM legoset WHERE name IN (:name)").bind(0, + Arrays.asList("unknown", "dunno", "other")).then().as( + StepVerifier::create).verifyComplete(); + + verify(statement).bind(0, "unknown"); + verify(statement).bind(1, "dunno"); + verify(statement).bind(2, "other"); + verify(statement).execute(); + verifyNoMoreInteractions(statement); + } + + @Test + public void executeShouldBindValues() { + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT * FROM table WHERE key = $1").bind(0, + Parameter.from("foo")).then().as(StepVerifier::create).verifyComplete(); + + verify(statement).bind(0, "foo"); + + databaseClient.sql("SELECT * FROM table WHERE key = $1").bind("$1", + "foo").then().as(StepVerifier::create).verifyComplete(); + + verify(statement).bind("$1", "foo"); + } + + @Test + public void executeShouldBindNamedValuesByIndex() { + + Statement statement = mockStatementFor("SELECT * FROM table WHERE key = $1"); + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT * FROM table WHERE key = :key").bind("key", + "foo").then().as(StepVerifier::create).verifyComplete(); + + verify(statement).bind(0, "foo"); + } + + @Test + public void rowsUpdatedShouldEmitSingleValue() { + + Result result = mock(Result.class); + when(result.getRowsUpdated()).thenReturn(Mono.empty(), Mono.just(2), + Flux.just(1, 2, 3)); + mockStatementFor("DROP TABLE tab;", result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("DROP TABLE tab;").fetch().rowsUpdated().as( + StepVerifier::create).expectNextCount(1).verifyComplete(); + + databaseClient.sql("DROP TABLE tab;").fetch().rowsUpdated().as( + StepVerifier::create).expectNextCount(1).verifyComplete(); + + databaseClient.sql("DROP TABLE tab;").fetch().rowsUpdated().as( + StepVerifier::create).expectNextCount(1).verifyComplete(); + } + + @Test + public void selectShouldEmitFirstValue() { + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata( + MockColumnMetadata.builder().name("name").build()).build(); + + MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata); + MockResult result = resultBuilder.row(MockRow.builder().identified(0, Object.class, "Walter").build()) + .row(MockRow.builder().identified(0, Object.class, "White").build()).build(); + + mockStatementFor("SELECT * FROM person", result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT * FROM person").map(row -> row.get(0)) + .first() + .as(StepVerifier::create) + .expectNext("Walter") + .verifyComplete(); + } + + @Test + public void selectShouldEmitAllValues() { + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata( + MockColumnMetadata.builder().name("name").build()).build(); + + MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata); + MockResult result = resultBuilder.row(MockRow.builder().identified(0, Object.class, "Walter").build()) + .row(MockRow.builder().identified(0, Object.class, "White").build()).build(); + + mockStatementFor("SELECT * FROM person", result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT * FROM person").map(row -> row.get(0)) + .all() + .as(StepVerifier::create) + .expectNext("Walter") + .expectNext("White") + .verifyComplete(); + } + + @Test + public void selectOneShouldFailWithException() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata( + MockColumnMetadata.builder().name("name").build()).build(); + + MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata); + MockResult result = resultBuilder.row(MockRow.builder().identified(0, Object.class, "Walter").build()) + .row(MockRow.builder().identified(0, Object.class, "White").build()).build(); + + mockStatementFor("SELECT * FROM person", result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT * FROM person").map(row -> row.get(0)) + .one() + .as(StepVerifier::create) + .verifyError(IncorrectResultSizeDataAccessException.class); + } + + @Test + public void shouldApplyExecuteFunction() { + + Statement statement = mockStatement(); + MockResult result = mockSingleColumnResult( + MockRow.builder().identified(0, Object.class, "Walter")); + + DatabaseClient databaseClient = databaseClientBuilder.executeFunction( + stmnt -> Mono.just(result)).build(); + + databaseClient.sql("SELECT").fetch().all().as( + StepVerifier::create).expectNextCount(1).verifyComplete(); + + verifyNoInteractions(statement); + } + + @Test + public void shouldApplyPreparedOperation() { + + MockResult result = mockSingleColumnResult( + MockRow.builder().identified(0, Object.class, "Walter")); + Statement statement = mockStatementFor("SELECT * FROM person", result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql(new PreparedOperation() { + + @Override + public String toQuery() { + return "SELECT * FROM person"; + } + + @Override + public String getSource() { + return "SELECT"; + } + + @Override + public void bindTo(BindTarget target) { + target.bind("index", "value"); + } + }).fetch().all().as( + StepVerifier::create).expectNextCount(1).verifyComplete(); + + verify(statement).bind("index", "value"); + } + + @Test + public void shouldApplyStatementFilterFunctions() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata( + MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata).build(); + + Statement statement = mockStatement(result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT").filter( + (s, next) -> next.execute(s.returnGeneratedValues("foo"))).filter( + (s, next) -> next.execute( + s.returnGeneratedValues("bar"))).fetch().all().as( + StepVerifier::create).verifyComplete(); + + InOrder inOrder = inOrder(statement); + inOrder.verify(statement).returnGeneratedValues("foo"); + inOrder.verify(statement).returnGeneratedValues("bar"); + inOrder.verify(statement).execute(); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void shouldApplySimpleStatementFilterFunctions() { + + MockResult result = mockSingleColumnEmptyResult(); + + Statement statement = mockStatement(result); + + DatabaseClient databaseClient = databaseClientBuilder.build(); + + databaseClient.sql("SELECT").filter( + s -> s.returnGeneratedValues("foo")).filter( + s -> s.returnGeneratedValues("bar")).fetch().all().as( + StepVerifier::create).verifyComplete(); + + InOrder inOrder = inOrder(statement); + inOrder.verify(statement).returnGeneratedValues("foo"); + inOrder.verify(statement).returnGeneratedValues("bar"); + inOrder.verify(statement).execute(); + inOrder.verifyNoMoreInteractions(); + } + + private Statement mockStatement() { + return mockStatementFor(null, null); + } + + private Statement mockStatement(Result result) { + return mockStatementFor(null, result); + } + + private Statement mockStatementFor(String sql) { + return mockStatementFor(sql, null); + } + + private Statement mockStatementFor(@Nullable String sql, @Nullable Result result) { + + Statement statement = mock(Statement.class); + when(connection.createStatement(sql == null ? anyString() : eq(sql))).thenReturn( + statement); + when(statement.returnGeneratedValues(anyString())).thenReturn(statement); + when(statement.returnGeneratedValues()).thenReturn(statement); + + doReturn(result == null ? Mono.empty() : Flux.just(result)).when( + statement).execute(); + + return statement; + } + + private MockResult mockSingleColumnEmptyResult() { + return mockSingleColumnResult(null); + } + + /** + * Mocks a {@link Result} with a single column "name" and a single row if a non null + * row is provided. + */ + private MockResult mockSingleColumnResult(@Nullable MockRow.Builder row) { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata( + MockColumnMetadata.builder().name("name").build()).build(); + + MockResult.Builder resultBuilder = MockResult.builder().rowMetadata(metadata); + if (row != null) { + resultBuilder = resultBuilder.row(row.build()); + } + return resultBuilder.build(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/H2DatabaseClientIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/H2DatabaseClientIntegrationTests.java new file mode 100644 index 000000000000..7bf3f3883cbf --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/H2DatabaseClientIntegrationTests.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.h2.H2ConnectionFactory; +import io.r2dbc.spi.ConnectionFactory; + +/** + * Integration tests for {@link DatabaseClient} against H2. + * + * @author Mark Paluch + */ +public class H2DatabaseClientIntegrationTests + extends AbstractDatabaseClientIntegrationTests { + + public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" // + + " id serial CONSTRAINT id PRIMARY KEY,\n" // + + " version integer NULL,\n" // + + " name varchar(255) NOT NULL,\n" // + + " manual integer NULL\n" // + + ");"; + + @Override + protected ConnectionFactory createConnectionFactory() { + return H2ConnectionFactory.inMemory("r2dbc-test"); + } + + @Override + protected String getCreateTableStatement() { + return CREATE_TABLE_LEGOSET; + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/H2TransactionalDatabaseClientIntegrationTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/H2TransactionalDatabaseClientIntegrationTests.java new file mode 100644 index 000000000000..e326f851ab1e --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/H2TransactionalDatabaseClientIntegrationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import io.r2dbc.h2.H2ConnectionFactory; +import io.r2dbc.spi.ConnectionFactory; + +/** + * Integration tests for {@link DatabaseClient} against H2. + * + * @author Mark Paluch + */ +public class H2TransactionalDatabaseClientIntegrationTests + extends AbstractTransactionalDatabaseClientIntegrationTests { + + public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" // + + " id integer CONSTRAINT id PRIMARY KEY,\n" // + + " version integer NULL,\n" // + + " name varchar(255) NOT NULL,\n" // + + " manual integer NULL\n" // + + ");"; + + @Override + protected ConnectionFactory createConnectionFactory() { + return H2ConnectionFactory.inMemory("r2dbc-transactional"); + } + + @Override + protected String getCreateTableStatement() { + return CREATE_TABLE_LEGOSET; + } +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java new file mode 100644 index 000000000000..f332b222e8b1 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/NamedParameterUtilsUnitTests.java @@ -0,0 +1,462 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import org.springframework.r2dbc.core.binding.BindMarkersFactory; +import org.springframework.r2dbc.core.binding.BindTarget; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for {@link NamedParameterUtils}. + * + * @author Mark Paluch + * @author Jens Schauder + */ +public class NamedParameterUtilsUnitTests { + + private final BindMarkersFactory BIND_MARKERS = BindMarkersFactory.indexed("$", 1); + + @Test + public void shouldParseSql() { + String sql = "xxx :a yyyy :b :c :a zzzzz"; + ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); + assertThat(psql.getParameterNames()).containsExactly("a", "b", "c", "a"); + assertThat(psql.getTotalParameterCount()).isEqualTo(4); + assertThat(psql.getNamedParameterCount()).isEqualTo(3); + + String sql2 = "xxx &a yyyy ? zzzzz"; + ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2); + assertThat(psql2.getParameterNames()).containsExactly("a"); + assertThat(psql2.getTotalParameterCount()).isEqualTo(1); + assertThat(psql2.getNamedParameterCount()).isEqualTo(1); + + String sql3 = "xxx &ä+:ö" + '\t' + ":ü%10 yyyy ? zzzzz"; + ParsedSql psql3 = NamedParameterUtils.parseSqlStatement(sql3); + assertThat(psql3.getParameterNames()).containsExactly("ä", "ö", "ü"); + } + + @Test + public void substituteNamedParameters() { + MapBindParameterSource namedParams = new MapBindParameterSource(new HashMap<>()); + namedParams.addValue("a", "a").addValue("b", "b").addValue("c", "c"); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + "xxx :a :b :c", BIND_MARKERS, namedParams); + + assertThat(operation.toQuery()).isEqualTo("xxx $1 $2 $3"); + + PreparedOperation operation2 = NamedParameterUtils.substituteNamedParameters( + "xxx :a :b :c", BindMarkersFactory.named("@", "P", 8), namedParams); + + assertThat(operation2.toQuery()).isEqualTo("xxx @P0a @P1b @P2c"); + } + + @Test + public void substituteObjectArray() { + MapBindParameterSource namedParams = new MapBindParameterSource(new HashMap<>()); + namedParams.addValue("a", Arrays.asList(new Object[] { "Walter", "Heisenberg" }, + new Object[] { "Walt Jr.", "Flynn" })); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + "xxx :a", BIND_MARKERS, namedParams); + + assertThat(operation.toQuery()).isEqualTo("xxx ($1, $2), ($3, $4)"); + } + + @Test + public void shouldBindObjectArray() { + MapBindParameterSource namedParams = new MapBindParameterSource(new HashMap<>()); + namedParams.addValue("a", Arrays.asList(new Object[] { "Walter", "Heisenberg" }, + new Object[] { "Walt Jr.", "Flynn" })); + + BindTarget bindTarget = mock(BindTarget.class); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + "xxx :a", BIND_MARKERS, namedParams); + operation.bindTo(bindTarget); + + verify(bindTarget).bind(0, "Walter"); + verify(bindTarget).bind(1, "Heisenberg"); + verify(bindTarget).bind(2, "Walt Jr."); + verify(bindTarget).bind(3, "Flynn"); + } + + @Test + public void parseSqlContainingComments() { + String sql1 = "/*+ HINT */ xxx /* comment ? */ :a yyyy :b :c :a zzzzz -- :xx XX\n"; + + ParsedSql psql1 = NamedParameterUtils.parseSqlStatement(sql1); + assertThat(expand(psql1)).isEqualTo( + "/*+ HINT */ xxx /* comment ? */ $1 yyyy $2 $3 $1 zzzzz -- :xx XX\n"); + + MapBindParameterSource paramMap = new MapBindParameterSource(new HashMap<>()); + paramMap.addValue("a", "a"); + paramMap.addValue("b", "b"); + paramMap.addValue("c", "c"); + + String sql2 = "/*+ HINT */ xxx /* comment ? */ :a yyyy :b :c :a zzzzz -- :xx XX"; + ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2); + assertThat(expand(psql2)).isEqualTo( + "/*+ HINT */ xxx /* comment ? */ $1 yyyy $2 $3 $1 zzzzz -- :xx XX"); + } + + @Test + public void parseSqlStatementWithPostgresCasting() { + String expectedSql = "select 'first name' from artists where id = $1 and birth_date=$2::timestamp"; + String sql = "select 'first name' from artists where id = :id and birth_date=:birthDate::timestamp"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + parsedSql, BIND_MARKERS, new MapBindParameterSource()); + + assertThat(operation.toQuery()).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithPostgresContainedOperator() { + String expectedSql = "select 'first name' from artists where info->'stat'->'albums' = ?? $1 and '[\"1\",\"2\",\"3\"]'::jsonb ?? '4'"; + String sql = "select 'first name' from artists where info->'stat'->'albums' = ?? :album and '[\"1\",\"2\",\"3\"]'::jsonb ?? '4'"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(parsedSql.getTotalParameterCount()).isEqualTo(1); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithPostgresAnyArrayStringsExistsOperator() { + String expectedSql = "select '[\"3\", \"11\"]'::jsonb ?| '{1,3,11,12,17}'::text[]"; + String sql = "select '[\"3\", \"11\"]'::jsonb ?| '{1,3,11,12,17}'::text[]"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(parsedSql.getTotalParameterCount()).isEqualTo(0); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithPostgresAllArrayStringsExistsOperator() { + String expectedSql = "select '[\"3\", \"11\"]'::jsonb ?& '{1,3,11,12,17}'::text[] AND $1 = 'Back in Black'"; + String sql = "select '[\"3\", \"11\"]'::jsonb ?& '{1,3,11,12,17}'::text[] AND :album = 'Back in Black'"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + assertThat(parsedSql.getTotalParameterCount()).isEqualTo(1); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithEscapedColon() { + String expectedSql = "select '0\\:0' as a, foo from bar where baz < DATE($1 23:59:59) and baz = $2"; + String sql = "select '0\\:0' as a, foo from bar where baz < DATE(:p1 23\\:59\\:59) and baz = :p2"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(parsedSql.getParameterNames()).containsExactly("p1", "p2"); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithBracketDelimitedParameterNames() { + String expectedSql = "select foo from bar where baz = b$1$2z"; + String sql = "select foo from bar where baz = b:{p1}:{p2}z"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + assertThat(parsedSql.getParameterNames()).containsExactly("p1", "p2"); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithEmptyBracketsOrBracketsInQuotes() { + String expectedSql = "select foo from bar where baz = b:{}z"; + String sql = "select foo from bar where baz = b:{}z"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(parsedSql.getParameterNames()).isEmpty(); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + + String expectedSql2 = "select foo from bar where baz = 'b:{p1}z'"; + String sql2 = "select foo from bar where baz = 'b:{p1}z'"; + + ParsedSql parsedSql2 = NamedParameterUtils.parseSqlStatement(sql2); + assertThat(parsedSql2.getParameterNames()).isEmpty(); + assertThat(expand(parsedSql2)).isEqualTo(expectedSql2); + } + + @Test + public void parseSqlStatementWithSingleLetterInBrackets() { + String expectedSql = "select foo from bar where baz = b$1z"; + String sql = "select foo from bar where baz = b:{p}z"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); + assertThat(parsedSql.getParameterNames()).containsExactly("p"); + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithLogicalAnd() { + String expectedSql = "xxx & yyyy"; + + ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(expectedSql); + + assertThat(expand(parsedSql)).isEqualTo(expectedSql); + } + + @Test + public void substituteNamedParametersWithLogicalAnd() { + + String expectedSql = "xxx & yyyy"; + + assertThat(expand(expectedSql)).isEqualTo(expectedSql); + } + + @Test + public void variableAssignmentOperator() { + String expectedSql = "x := 1"; + + assertThat(expand(expectedSql)).isEqualTo(expectedSql); + } + + @Test + public void parseSqlStatementWithQuotedSingleQuote() { + String sql = "SELECT ':foo'':doo', :xxx FROM DUAL"; + + ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(psql.getTotalParameterCount()).isEqualTo(1); + assertThat(psql.getParameterNames()).containsExactly("xxx"); + } + + @Test + public void parseSqlStatementWithQuotesAndCommentBefore() { + String sql = "SELECT /*:doo*/':foo', :xxx FROM DUAL"; + + ParsedSql psql = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(psql.getTotalParameterCount()).isEqualTo(1); + assertThat(psql.getParameterNames()).containsExactly("xxx"); + } + + @Test + public void parseSqlStatementWithQuotesAndCommentAfter() { + String sql2 = "SELECT ':foo'/*:doo*/, :xxx FROM DUAL"; + + ParsedSql psql2 = NamedParameterUtils.parseSqlStatement(sql2); + + assertThat(psql2.getTotalParameterCount()).isEqualTo(1); + assertThat(psql2.getParameterNames()).containsExactly("xxx"); + } + + @Test + public void shouldAllowParsingMultipleUseOfParameter() { + + String sql = "SELECT * FROM person where name = :id or lastname = :id"; + + ParsedSql parsed = NamedParameterUtils.parseSqlStatement(sql); + + assertThat(parsed.getTotalParameterCount()).isEqualTo(2); + assertThat(parsed.getNamedParameterCount()).isEqualTo(1); + assertThat(parsed.getParameterNames()).containsExactly("id", "id"); + } + + @Test + public void multipleEqualParameterReferencesBindsValueOnce() { + String sql = "SELECT * FROM person where name = :id or lastname = :id"; + + BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + sql, factory, new MapBindParameterSource( + Collections.singletonMap("id", Parameter.from("foo")))); + + assertThat(operation.toQuery()).isEqualTo( + "SELECT * FROM person where name = $0 or lastname = $0"); + + operation.bindTo(new BindTarget() { + + @Override + public void bind(String identifier, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void bind(int index, Object value) { + assertThat(index).isEqualTo(0); + assertThat(value).isEqualTo("foo"); + } + + @Override + public void bindNull(String identifier, Class type) { + throw new UnsupportedOperationException(); + } + + @Override + public void bindNull(int index, Class type) { + throw new UnsupportedOperationException(); + } + }); + } + + @Test + public void multipleEqualCollectionParameterReferencesBindsValueOnce() { + String sql = "SELECT * FROM person where name IN (:ids) or lastname IN (:ids)"; + + BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0); + + MultiValueMap bindings = new LinkedMultiValueMap<>(); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + sql, factory, new MapBindParameterSource(Collections.singletonMap("ids", + Parameter.from(Arrays.asList("foo", "bar", "baz"))))); + + assertThat(operation.toQuery()).isEqualTo( + "SELECT * FROM person where name IN ($0, $1, $2) or lastname IN ($0, $1, $2)"); + + operation.bindTo(new BindTarget() { + + @Override + public void bind(String identifier, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void bind(int index, Object value) { + assertThat(index).isIn(0, 1, 2); + assertThat(value).isIn("foo", "bar", "baz"); + + bindings.add(index, value); + } + + @Override + public void bindNull(String identifier, Class type) { + throw new UnsupportedOperationException(); + } + + @Override + public void bindNull(int index, Class type) { + throw new UnsupportedOperationException(); + } + }); + + assertThat(bindings).containsEntry(0, Collections.singletonList("foo")) // + .containsEntry(1, Collections.singletonList("bar")) // + .containsEntry(2, Collections.singletonList("baz")); + } + + @Test + public void multipleEqualParameterReferencesForAnonymousMarkersBindsValueMultipleTimes() { + String sql = "SELECT * FROM person where name = :id or lastname = :id"; + + BindMarkersFactory factory = BindMarkersFactory.anonymous("?"); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + sql, factory, new MapBindParameterSource( + Collections.singletonMap("id", Parameter.from("foo")))); + + assertThat(operation.toQuery()).isEqualTo( + "SELECT * FROM person where name = ? or lastname = ?"); + + Map bindValues = new LinkedHashMap<>(); + + operation.bindTo(new BindTarget() { + + @Override + public void bind(String identifier, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void bind(int index, Object value) { + bindValues.put(index, value); + } + + @Override + public void bindNull(String identifier, Class type) { + throw new UnsupportedOperationException(); + } + + @Override + public void bindNull(int index, Class type) { + throw new UnsupportedOperationException(); + } + }); + + assertThat(bindValues).hasSize(2).containsEntry(0, "foo").containsEntry(1, "foo"); + } + + @Test + public void multipleEqualParameterReferencesBindsNullOnce() { + String sql = "SELECT * FROM person where name = :id or lastname = :id"; + + BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0); + + PreparedOperation operation = NamedParameterUtils.substituteNamedParameters( + sql, factory, new MapBindParameterSource( + Collections.singletonMap("id", Parameter.empty(String.class)))); + + assertThat(operation.toQuery()).isEqualTo( + "SELECT * FROM person where name = $0 or lastname = $0"); + + operation.bindTo(new BindTarget() { + + @Override + public void bind(String identifier, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void bind(int index, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void bindNull(String identifier, Class type) { + throw new UnsupportedOperationException(); + } + + @Override + public void bindNull(int index, Class type) { + assertThat(index).isEqualTo(0); + assertThat(type).isEqualTo(String.class); + } + }); + } + + private String expand(ParsedSql sql) { + return NamedParameterUtils.substituteNamedParameters(sql, BIND_MARKERS, + new MapBindParameterSource()).toQuery(); + } + + private String expand(String sql) { + return NamedParameterUtils.substituteNamedParameters(sql, BIND_MARKERS, + new MapBindParameterSource()).toQuery(); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/AnonymousBindMarkersUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/AnonymousBindMarkersUnitTests.java new file mode 100644 index 000000000000..87169d4bca8d --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/AnonymousBindMarkersUnitTests.java @@ -0,0 +1,59 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for {@link AnonymousBindMarkers}. + * + * @author Mark Paluch + */ +class AnonymousBindMarkersUnitTests { + + @Test + public void shouldCreateNewBindMarkers() { + BindMarkersFactory factory = BindMarkersFactory.anonymous("?"); + + BindMarkers bindMarkers1 = factory.create(); + BindMarkers bindMarkers2 = factory.create(); + + assertThat(bindMarkers1.next().getPlaceholder()).isEqualTo("?"); + assertThat(bindMarkers2.next().getPlaceholder()).isEqualTo("?"); + } + + @Test + public void shouldBindByIndex() { + BindTarget bindTarget = mock(BindTarget.class); + + BindMarkers bindMarkers = BindMarkersFactory.anonymous("?").create(); + + BindMarker first = bindMarkers.next(); + BindMarker second = bindMarkers.next(); + + second.bind(bindTarget, "foo"); + first.bindNull(bindTarget, Object.class); + + verify(bindTarget).bindNull(0, Object.class); + verify(bindTarget).bind(1, "foo"); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/BindingsUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/BindingsUnitTests.java new file mode 100644 index 000000000000..0bfef03742d4 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/BindingsUnitTests.java @@ -0,0 +1,146 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + + +/** + * Unit tests for {@link Bindings}. + * + * @author Mark Paluch + */ +class BindingsUnitTests { + + BindMarkersFactory markersFactory = BindMarkersFactory.indexed("$", 1); + BindTarget bindTarget = mock(BindTarget.class); + + @Test + void shouldCreateBindings() { + MutableBindings bindings = new MutableBindings(markersFactory.create()); + + bindings.bind(bindings.nextMarker(), "foo"); + bindings.bindNull(bindings.nextMarker(), String.class); + + assertThat(bindings).hasSize(2); + } + + @Test + void shouldApplyValueBinding() { + MutableBindings bindings = new MutableBindings(markersFactory.create()); + + bindings.bind(bindings.nextMarker(), "foo"); + bindings.apply(bindTarget); + + verify(bindTarget).bind(0, "foo"); + } + + @Test + void shouldApplySimpleValueBinding() { + MutableBindings bindings = new MutableBindings(markersFactory.create()); + + BindMarker marker = bindings.bind("foo"); + bindings.apply(bindTarget); + + assertThat(marker.getPlaceholder()).isEqualTo("$1"); + verify(bindTarget).bind(0, "foo"); + } + + @Test + void shouldApplyNullBinding() { + MutableBindings bindings = new MutableBindings(markersFactory.create()); + + bindings.bindNull(bindings.nextMarker(), String.class); + + bindings.apply(bindTarget); + + verify(bindTarget).bindNull(0, String.class); + } + + @Test + void shouldApplySimpleNullBinding() { + MutableBindings bindings = new MutableBindings(markersFactory.create()); + + BindMarker marker = bindings.bindNull(String.class); + bindings.apply(bindTarget); + + assertThat(marker.getPlaceholder()).isEqualTo("$1"); + verify(bindTarget).bindNull(0, String.class); + } + + @Test + void shouldConsumeBindings() { + MutableBindings bindings = new MutableBindings(markersFactory.create()); + + bindings.bind(bindings.nextMarker(), "foo"); + bindings.bindNull(bindings.nextMarker(), String.class); + + AtomicInteger counter = new AtomicInteger(); + + bindings.forEach(binding -> { + + if (binding.hasValue()) { + counter.incrementAndGet(); + assertThat(binding.getValue()).isEqualTo("foo"); + assertThat(binding.getBindMarker().getPlaceholder()).isEqualTo("$1"); + } + + if (binding.isNull()) { + counter.incrementAndGet(); + + assertThat(((Bindings.NullBinding) binding).getValueType()).isEqualTo(String.class); + assertThat(binding.getBindMarker().getPlaceholder()).isEqualTo("$2"); + } + }); + + assertThat(counter).hasValue(2); + } + + @Test + void shouldMergeBindings() { + BindMarkers markers = markersFactory.create(); + + BindMarker shared = markers.next(); + BindMarker leftMarker = markers.next(); + List left = new ArrayList<>(); + left.add(new Bindings.NullBinding(shared, String.class)); + left.add(new Bindings.ValueBinding(leftMarker, "left")); + + BindMarker rightMarker = markers.next(); + List right = new ArrayList<>(); + left.add(new Bindings.ValueBinding(shared, "override")); + left.add(new Bindings.ValueBinding(rightMarker, "right")); + + Bindings merged = Bindings.merge(new Bindings(left), new Bindings(right)); + + assertThat(merged).hasSize(3); + + merged.apply(bindTarget); + verify(bindTarget).bind(0, "override"); + verify(bindTarget).bind(1, "left"); + verify(bindTarget).bind(2, "right"); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/IndexedBindMarkersUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/IndexedBindMarkersUnitTests.java new file mode 100644 index 000000000000..9bc389d71887 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/IndexedBindMarkersUnitTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for {@link IndexedBindMarkers}. + * + * @author Mark Paluch + */ +class IndexedBindMarkersUnitTests { + + @Test + void shouldCreateNewBindMarkers() { + BindMarkersFactory factory = BindMarkersFactory.indexed("$", 0); + + BindMarkers bindMarkers1 = factory.create(); + BindMarkers bindMarkers2 = factory.create(); + + assertThat(bindMarkers1.next().getPlaceholder()).isEqualTo("$0"); + assertThat(bindMarkers2.next().getPlaceholder()).isEqualTo("$0"); + } + + @Test + void shouldCreateNewBindMarkersWithOffset() { + BindTarget bindTarget = mock(BindTarget.class); + + BindMarkers bindMarkers = BindMarkersFactory.indexed("$", 1).create(); + + BindMarker first = bindMarkers.next(); + first.bind(bindTarget, "foo"); + + BindMarker second = bindMarkers.next(); + second.bind(bindTarget, "bar"); + + assertThat(first.getPlaceholder()).isEqualTo("$1"); + assertThat(second.getPlaceholder()).isEqualTo("$2"); + verify(bindTarget).bind(0, "foo"); + verify(bindTarget).bind(1, "bar"); + } + + @Test + void nextShouldIncrementBindMarker() { + String[] prefixes = { "$", "?" }; + + for (String prefix : prefixes) { + + BindMarkers bindMarkers = BindMarkersFactory.indexed(prefix, 0).create(); + + BindMarker marker1 = bindMarkers.next(); + BindMarker marker2 = bindMarkers.next(); + + assertThat(marker1.getPlaceholder()).isEqualTo(prefix + "0"); + assertThat(marker2.getPlaceholder()).isEqualTo(prefix + "1"); + } + } + + @Test + void bindValueShouldBindByIndex() { + + BindTarget bindTarget = mock(BindTarget.class); + + BindMarkers bindMarkers = BindMarkersFactory.indexed("$", 0).create(); + + bindMarkers.next().bind(bindTarget, "foo"); + bindMarkers.next().bind(bindTarget, "bar"); + + verify(bindTarget).bind(0, "foo"); + verify(bindTarget).bind(1, "bar"); + } + + @Test + void bindNullShouldBindByIndex() { + BindTarget bindTarget = mock(BindTarget.class); + + BindMarkers bindMarkers = BindMarkersFactory.indexed("$", 0).create(); + + bindMarkers.next(); // ignore + bindMarkers.next().bindNull(bindTarget, Integer.class); + + verify(bindTarget).bindNull(1, Integer.class); + } + +} diff --git a/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/NamedBindMarkersUnitTests.java b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/NamedBindMarkersUnitTests.java new file mode 100644 index 000000000000..d83691e5b2d7 --- /dev/null +++ b/spring-r2dbc/src/test/java/org/springframework/r2dbc/core/binding/NamedBindMarkersUnitTests.java @@ -0,0 +1,115 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core.binding; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Unit tests for {@link NamedBindMarkers}. + * + * @author Mark Paluch + */ +class NamedBindMarkersUnitTests { + + @Test + void shouldCreateNewBindMarkers() { + BindMarkersFactory factory = BindMarkersFactory.named("@", "p", 32); + + BindMarkers bindMarkers1 = factory.create(); + BindMarkers bindMarkers2 = factory.create(); + + assertThat(bindMarkers1.next().getPlaceholder()).isEqualTo("@p0"); + assertThat(bindMarkers2.next().getPlaceholder()).isEqualTo("@p0"); + } + + @ParameterizedTest + @ValueSource(strings = { "$", "?" }) + void nextShouldIncrementBindMarker(String prefix) { + BindMarkers bindMarkers = BindMarkersFactory.named(prefix, "p", 32).create(); + + BindMarker marker1 = bindMarkers.next(); + BindMarker marker2 = bindMarkers.next(); + + assertThat(marker1.getPlaceholder()).isEqualTo(prefix + "p0"); + assertThat(marker2.getPlaceholder()).isEqualTo(prefix + "p1"); + } + + @Test + void nextShouldConsiderNameHint() { + BindMarkers bindMarkers = BindMarkersFactory.named("@", "x", 32).create(); + + BindMarker marker1 = bindMarkers.next("foo1bar"); + BindMarker marker2 = bindMarkers.next(); + + assertThat(marker1.getPlaceholder()).isEqualTo("@x0foo1bar"); + assertThat(marker2.getPlaceholder()).isEqualTo("@x1"); + } + + @Test + void nextShouldConsiderFilteredNameHint() { + BindMarkers bindMarkers = BindMarkersFactory.named("@", "p", 32, + s -> s.chars().filter(Character::isAlphabetic).collect(StringBuilder::new, + StringBuilder::appendCodePoint, StringBuilder::append).toString()).create(); + + BindMarker marker1 = bindMarkers.next("foo1.bar?"); + BindMarker marker2 = bindMarkers.next(); + + assertThat(marker1.getPlaceholder()).isEqualTo("@p0foobar"); + assertThat(marker2.getPlaceholder()).isEqualTo("@p1"); + } + + @Test + void nextShouldConsiderNameLimit() { + BindMarkers bindMarkers = BindMarkersFactory.named("@", "p", 10).create(); + + BindMarker marker1 = bindMarkers.next("123456789"); + + assertThat(marker1.getPlaceholder()).isEqualTo("@p012345678"); + } + + @Test + void bindValueShouldBindByName() { + BindTarget bindTarget = mock(BindTarget.class); + + BindMarkers bindMarkers = BindMarkersFactory.named("@", "p", 32).create(); + + bindMarkers.next().bind(bindTarget, "foo"); + bindMarkers.next().bind(bindTarget, "bar"); + + verify(bindTarget).bind("p0", "foo"); + verify(bindTarget).bind("p1", "bar"); + } + + @Test + void bindNullShouldBindByName() { + BindTarget bindTarget = mock(BindTarget.class); + + BindMarkers bindMarkers = BindMarkersFactory.named("@", "p", 32).create(); + + bindMarkers.next(); // ignore + bindMarkers.next().bindNull(bindTarget, Integer.class); + + verify(bindTarget).bindNull("p1", Integer.class); + } + +} diff --git a/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/DatabaseClientExtensionsTests.kt b/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/DatabaseClientExtensionsTests.kt new file mode 100644 index 000000000000..24f672c5cd4e --- /dev/null +++ b/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/DatabaseClientExtensionsTests.kt @@ -0,0 +1,105 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core + +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import reactor.core.publisher.Mono + +/** + * Unit tests for [DatabaseClient] extensions. + * + * @author Sebastien Deleuze + * @author Jonas Bark + * @author Mark Paluch + */ +class DatabaseClientExtensionsTests { + + @Test + fun bindByIndexShouldBindValue() { + val spec = mockk() + every { spec.bind(eq(0), any()) } returns spec + + runBlocking { + spec.bind(0, "foo") + } + + verify { + spec.bind(0, Parameter.fromOrEmpty("foo", String::class.java)) + } + } + + @Test + fun bindByIndexShouldBindNull() { + val spec = mockk() + every { spec.bind(eq(0), any()) } returns spec + + runBlocking { + spec.bind(0, null) + } + + verify { + spec.bind(0, Parameter.empty(String::class.java)) + } + } + + @Test + fun bindByNameShouldBindValue() { + val spec = mockk() + every { spec.bind(eq("field"), any()) } returns spec + + runBlocking { + spec.bind("field", "foo") + } + + verify { + spec.bind("field", Parameter.fromOrEmpty("foo", String::class.java)) + } + } + + @Test + fun bindByNameShouldBindNull() { + val spec = mockk() + every { spec.bind(eq("field"), any()) } returns spec + + runBlocking { + spec.bind("field", null) + } + + verify { + spec.bind("field", Parameter.empty(String::class.java)) + } + } + + @Test + fun genericExecuteSpecAwait() { + val spec = mockk() + every { spec.then() } returns Mono.empty() + + runBlocking { + spec.await() + } + + verify { + spec.then() + } + } + +} diff --git a/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/RowsFetchSpecExtensionsTests.kt b/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/RowsFetchSpecExtensionsTests.kt new file mode 100644 index 000000000000..3146db956cc8 --- /dev/null +++ b/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/RowsFetchSpecExtensionsTests.kt @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.r2dbc.core + +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.junit.jupiter.api.Test +import org.springframework.dao.EmptyResultDataAccessException +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono + +/** + * Unit tests for [RowsFetchSpec] extensions. + * + * @author Sebastien Deleuze + * @author Mark Paluch + */ +class RowsFetchSpecExtensionsTests { + + @Test + fun awaitOneWithValue() { + val spec = mockk>() + every { spec.one() } returns Mono.just("foo") + + runBlocking { + assertThat(spec.awaitOne()).isEqualTo("foo") + } + + verify { + spec.one() + } + } + + @Test + fun awaitOneWithNull() { + val spec = mockk>() + every { spec.one() } returns Mono.empty() + + assertThatExceptionOfType(EmptyResultDataAccessException::class.java).isThrownBy { + runBlocking { spec.awaitOne() } + } + + verify { + spec.one() + } + } + + @Test + fun awaitOneOrNullWithValue() { + val spec = mockk>() + every { spec.one() } returns Mono.just("foo") + + runBlocking { + assertThat(spec.awaitOneOrNull()).isEqualTo("foo") + } + + verify { + spec.one() + } + } + + @Test + fun awaitOneOrNullWithNull() { + val spec = mockk>() + every { spec.one() } returns Mono.empty() + + runBlocking { + assertThat(spec.awaitOneOrNull()).isNull() + } + + verify { + spec.one() + } + } + + @Test + fun awaitFirstWithValue() { + val spec = mockk>() + every { spec.first() } returns Mono.just("foo") + + runBlocking { + assertThat(spec.awaitFirst()).isEqualTo("foo") + } + + verify { + spec.first() + } + } + + @Test + fun awaitFirstWithNull() { + val spec = mockk>() + every { spec.first() } returns Mono.empty() + + assertThatExceptionOfType(EmptyResultDataAccessException::class.java).isThrownBy { + runBlocking { spec.awaitFirst() } + } + + verify { + spec.first() + } + } + + @Test + fun awaitFirstOrNullWithValue() { + val spec = mockk>() + every { spec.first() } returns Mono.just("foo") + + runBlocking { + assertThat(spec.awaitFirstOrNull()).isEqualTo("foo") + } + + verify { + spec.first() + } + } + + @Test + fun awaitFirstOrNullWithNull() { + val spec = mockk>() + every { spec.first() } returns Mono.empty() + + runBlocking { + assertThat(spec.awaitFirstOrNull()).isNull() + } + + verify { + spec.first() + } + } + + @Test + @ExperimentalCoroutinesApi + fun allAsFlow() { + val spec = mockk>() + every { spec.all() } returns Flux.just("foo", "bar", "baz") + + runBlocking { + assertThat(spec.flow().toList()).contains("foo", "bar", "baz") + } + + verify { + spec.all() + } + } + +} diff --git a/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/UpdatedRowsFetchSpecExtensionsTests.kt b/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/UpdatedRowsFetchSpecExtensionsTests.kt new file mode 100644 index 000000000000..eb7058c846da --- /dev/null +++ b/spring-r2dbc/src/test/kotlin/org/springframework/r2dbc/core/UpdatedRowsFetchSpecExtensionsTests.kt @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.r2dbc.core + +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import kotlinx.coroutines.runBlocking +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import reactor.core.publisher.Mono + +/** + * Unit tests for [UpdatedRowsFetchSpec] extensions. + * + * @author Fred Montariol + */ +class UpdatedRowsFetchSpecExtensionsTests { + + @Test + fun awaitRowsUpdatedWithValue() { + val spec = mockk() + every { spec.rowsUpdated() } returns Mono.just(42) + + runBlocking { + assertThat(spec.awaitRowsUpdated()).isEqualTo(42) + } + + verify { + spec.rowsUpdated() + } + } + +} diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-schema-failed-drop-comments.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-schema-failed-drop-comments.sql new file mode 100644 index 000000000000..8ce888986b1b --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-schema-failed-drop-comments.sql @@ -0,0 +1,5 @@ +-- Failed DROP can be ignored if necessary +drop table users; + +-- Create the test table +create table users (last_name varchar(50) not null); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-schema.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-schema.sql new file mode 100644 index 000000000000..4de3841ec12b --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-schema.sql @@ -0,0 +1,3 @@ +drop table users if exists; + +create table users (last_name varchar(50) not null); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-endings.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-endings.sql new file mode 100644 index 000000000000..78a82189b06d --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-endings.sql @@ -0,0 +1,2 @@ +insert into users (last_name) values ('Heisenberg')@@ +insert into users (last_name) values ('Jesse')@@ diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-escaped-literal.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-escaped-literal.sql new file mode 100644 index 000000000000..3ba33ccacd77 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-escaped-literal.sql @@ -0,0 +1 @@ +insert into users (last_name) values ('''Heisenberg'''); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-h2.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-h2.sql new file mode 100644 index 000000000000..a62e920a2ccd --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-h2.sql @@ -0,0 +1 @@ +INSERT INTO users(first_name, last_name) values('Walter', 'White'); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-multi-newline.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-multi-newline.sql new file mode 100644 index 000000000000..6239f6adcc10 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-multi-newline.sql @@ -0,0 +1,5 @@ +insert into users (last_name) +values ('Walter') + +insert into users (last_name) +values ('Jesse') diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-multiple.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-multiple.sql new file mode 100644 index 000000000000..ea185476c7c0 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-multiple.sql @@ -0,0 +1,2 @@ +insert into users (last_name) values ('Heisenberg'); +insert into users (last_name) values ('Jesse'); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-mysql-escaped-literal.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-mysql-escaped-literal.sql new file mode 100644 index 000000000000..dae44d40e065 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data-mysql-escaped-literal.sql @@ -0,0 +1 @@ +insert into users (last_name) values ('\$Heisenberg\$'); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data.sql new file mode 100644 index 000000000000..85673705f7a5 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/db-test-data.sql @@ -0,0 +1 @@ +insert into users (last_name) values ('Heisenberg'); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-comments-and-leading-tabs.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-comments-and-leading-tabs.sql new file mode 100644 index 000000000000..ddb67f0cf1f3 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-comments-and-leading-tabs.sql @@ -0,0 +1,9 @@ +-- The next comment line starts with a tab. + -- x, y, z... + +insert into customer (id, name) +values (1, 'Walter White'); + -- This is also a comment with a leading tab. +insert into orders(id, order_date, customer_id) values (1, '2013-06-08', 1); + -- This is also a comment with a leading tab, a space, and a tab. +insert into orders(id, order_date, customer_id) values (2, '2013-06-08', 1); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-comments.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-comments.sql new file mode 100644 index 000000000000..82483ca4d369 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-comments.sql @@ -0,0 +1,16 @@ +-- The next comment line has no text after the '--' prefix. +-- +-- The next comment line starts with a space. + -- x, y, z... + +insert into customer (id, name) +values (1, 'Rod; Johnson'), (2, 'Adrian Collier'); +-- This is also a comment. +insert into orders(id, order_date, customer_id) +values (1, '2008-01-02', 2); +insert into orders(id, order_date, customer_id) values (1, '2008-01-02', 2); +INSERT INTO persons( person_id-- + , name) +VALUES( 1 -- person_id + , 'Name' --name +);-- \ No newline at end of file diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-multi-line-comments.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-multi-line-comments.sql new file mode 100644 index 000000000000..8cfa6d438af7 --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-multi-line-comments.sql @@ -0,0 +1,17 @@ +/* This is a multi line comment + * The next comment line has no text + + * The next comment line starts with a space. + * x, y, z... + */ + +INSERT INTO users(first_name, last_name) VALUES('Walter', 'White'); +-- This is also a comment. +/* + * Let's add another comment + * that covers multiple lines + */INSERT INTO +users(first_name, last_name) +VALUES( 'Jesse' -- first_name + , 'Pinkman' -- last_name +);-- diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-multi-line-nested-comments.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-multi-line-nested-comments.sql new file mode 100644 index 000000000000..5a3d3a1363fb --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/test-data-with-multi-line-nested-comments.sql @@ -0,0 +1,23 @@ +/* This is a multi line comment + * The next comment line has no text + + * The next comment line starts with a space. + * x, y, z... + */ + +INSERT INTO users(first_name, last_name) VALUES('Walter', 'White'); +-- This is also a comment. +/*------------------------------------------- +-- A fancy multi-line comments that puts +-- single line comments inside of a multi-line +-- comment block. +Moreover, the block comment end delimiter +appears on a line that can potentially also +be a single-line comment if we weren't +already inside a multi-line comment run. +-------------------------------------------*/ + INSERT INTO +users(first_name, last_name) -- This is a single line comment containing the block-end-comment sequence here */ but it's still a single-line comment +VALUES( 'Jesse' -- first_name + , 'Pinkman' -- last_name +);-- diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/users-data.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/users-data.sql new file mode 100644 index 000000000000..a6aa7838526e --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/users-data.sql @@ -0,0 +1,3 @@ +INSERT INTO +users(first_name, last_name) +values('Sam', 'Brannen'); diff --git a/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/users-schema.sql b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/users-schema.sql new file mode 100644 index 000000000000..80ffe23da9ad --- /dev/null +++ b/spring-r2dbc/src/test/resources/org/springframework/r2dbc/connection/init/users-schema.sql @@ -0,0 +1,7 @@ +DROP TABLE users IF EXISTS; + +CREATE TABLE users ( + id INTEGER NOT NULL IDENTITY, + first_name VARCHAR(50) NOT NULL, + last_name VARCHAR(50) NOT NULL +); diff --git a/src/docs/asciidoc/data-access.adoc b/src/docs/asciidoc/data-access.adoc index 4086e4ddd386..424205f22fe9 100644 --- a/src/docs/asciidoc/data-access.adoc +++ b/src/docs/asciidoc/data-access.adoc @@ -6536,6 +6536,647 @@ Ensuring that the database initializer is initialized first can also be easy. So +[[r2dbc]] +== Data Access with R2DBC + +https://r2dbc.io[R2DBC] ("Reactive Relational Database Connectivity") is a community-driven +specification effort to standardize access to SQL databases using reactive patterns. + + +[[r2dbc-packages]] +=== Package Hierarchy + +The Spring Framework's R2DBC abstraction framework consists of two different packages: + +* `core`: The `org.springframework.r2dbc.core` package contains the `DatabaseClient` +class plus a variety of related classes. See <>. + +* `connection`: The `org.springframework.r2dbc.connection` package contains a utility class +for easy `ConnectionFactory` access and various simple `ConnectionFactory` implementations +that you can use for testing and running unmodified R2DBC. See <>. + + +[[r2dbc-core]] +=== Using the R2DBC Core Classes to Control Basic R2DBC Processing and Error Handling + +This section covers how to use the R2DBC core classes to control basic R2DBC processing, +including error handling. It includes the following topics: + +* <> +* <> +* <> +* <> +* <> +* <> + +[[r2dbc-DatabaseClient]] +==== Using `DatabaseClient` + +`DatabaseClient` is the central class in the R2DBC core package. It handles the +creation and release of resources, which helps to avoid common errors, such as +forgetting to close the connection. It performs the basic tasks of the core R2DBC +workflow (such as statement creation and execution), leaving application code to provide +SQL and extract results. The `DatabaseClient` class: + +* Runs SQL queries +* Update statements and stored procedure calls +* Performs iteration over `Result` instances +* Catches R2DBC exceptions and translates them to the generic, more informative, exception +hierarchy defined in the `org.springframework.dao` package. (See <>.) + +The client has a functional, fluent API using reactive types for declarative composition. + +When you use the `DatabaseClient` for your code, you need only to implement +`java.util.function` interfaces, giving them a clearly defined contract. +Given a `Connection` provided by the `DatabaseClient` class, a `Function` +callback creates a `Publisher`. The same is true for mapping functions that +extract a `Row` result. + +You can use `DatabaseClient` within a DAO implementation through direct instantiation +with a `ConnectionFactory` reference, or you can configure it in a Spring IoC container +and give it to DAOs as a bean reference. + +The simplest way to create a `DatabaseClient` object is through a static factory method, as follows: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + DatabaseClient client = DatabaseClient.create(connectionFactory); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val client = DatabaseClient.create(connectionFactory) +---- + +NOTE: The `ConnectionFactory` should always be configured as a bean in the Spring IoC +container. + +The preceding method creates a `DatabaseClient` with default settings. + +You can also obtain a `Builder` instance from `DatabaseClient.builder()`. +You can customize the client by calling the following methods: + +* `….bindMarkers(…)`: Supply a specific `BindMarkersFactory` to configure named +parameter to database bind marker translation. +* `….executeFunction(…)`: Set the `ExecuteFunction` how `Statement` objects get + executed. +* `….namedParameters(false)`: Disable named parameter expansion. Enabled by default. + +TIP: Dialects are resolved by {api-spring-framework}/r2dbc/core/binding/BindMarkersFactoryResolver.html[`BindMarkersFactoryResolver`] + from a `ConnectionFactory`, typically by inspecting `ConnectionFactoryMetadata`. + + +You can let Spring auto-discover your `BindMarkersFactory` by registering a +class that implements `org.springframework.r2dbc.core.binding.BindMarkersFactoryResolver$BindMarkerFactoryProvider` +through `META-INF/spring.factories`. +`BindMarkersFactoryResolver` discovers bind marker provider implementations from +the class path using Spring's `SpringFactoriesLoader`. + + + +Currently supported databases are: + +* H2 +* MariaDB +* Microsoft SQL Server +* MySQL +* Postgres + +All SQL issued by this class is logged at the `DEBUG` level under the category +corresponding to the fully qualified class name of the client instance (typically +`DefaultDatabaseClient`). Additionally, each execution registers a checkpoint in +the reactive sequence to aid debugging. + +The following sections provide some examples of `DatabaseClient` usage. These examples +are not an exhaustive list of all of the functionality exposed by the `DatabaseClient`. +See the attendant {api-spring-framework}/r2dbc/core/DatabaseClient.html[javadoc] for that. + +[[r2dbc-DatabaseClient-examples-statement]] +===== Executing Statements + +`DatabaseClient` provides the basic functionality of running a statement. +The following example shows what you need to include for minimal but fully functional +code that creates a new table: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + Mono completion = client.sql("CREATE TABLE person (id VARCHAR(255) PRIMARY KEY, name VARCHAR(255), age INTEGER);") + .then(); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + client.sql("CREATE TABLE person (id VARCHAR(255) PRIMARY KEY, name VARCHAR(255), age INTEGER);") + .await() +---- + +`DatabaseClient` is designed for convenient, fluent usage. +It exposes intermediate, continuation, and terminal methods at each stage of the +execution specification. The preceding example above uses `then()` to return a completion +`Publisher` that completes as soon as the query (or queries, if the SQL query contains +multiple statements) completes. + +NOTE: `execute(…)` accepts either the SQL query string or a query `Supplier` +to defer the actual query creation until execution. + +[[r2dbc-DatabaseClient-examples-query]] +===== Querying (`SELECT`) + +SQL queries can return values through `Row` objects or the number of affected rows. +`DatabaseClient` can return the number of updated rows or the rows themselves, +depending on the issued query. + +The following query gets the `id` and `name` columns from a table: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + Mono> first = client.sql("SELECT id, name FROM person") + .fetch().first(); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val first = client.sql("SELECT id, name FROM person") + .fetch().awaitFirst() +---- + +The following query uses a bind variable: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + Mono> first = client.sql("SELECT id, name FROM person WHERE first_name = :fn") + .bind("fn", "Joe") + .fetch().first(); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val first = client.sql("SELECT id, name FROM person WHERE WHERE first_name = :fn") + .bind("fn", "Joe") + .fetch().awaitFirst() +---- + +You might have noticed the use of `fetch()` in the example above. `fetch()` is a +continuation operator that lets you specify how much data you want to consume. + +Calling `first()` returns the first row from the result and discards remaining rows. +You can consume data with the following operators: + +* `first()` return the first row of the entire result. Its Kotlin Coroutine variant +is named `awaitFirst()` for non-nullable return values and `awaitFirstOrNull()` +if the value is optional. +* `one()` returns exactly one result and fails if the result contains more rows. +Using Kotlin Coroutines, `awaitOne()` for exactly one value or `awaitOneOrNull()` +if the value may be `null`. +* `all()` returns all rows of the result. When using Kotlin Coroutines, use `flow()`. +* `rowsUpdated()` returns the number of affected rows (`INSERT`/`UPDATE`/`DELETE` +count). Its Kotlin Coroutine variant is named `awaitRowsUpdated()`. + +Without specifying further mapping details, queries return tabular results +as `Map` whose keys are case-insensitive column names that map to their column value. + +You can take control over result mapping by supplying a `Function` that gets +called for each `Row` so it can can return arbitrary values (singular values, +collections and maps, and objects). + +The following example extracts the `id` column and emits its value: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + Flux names = client.sql("SELECT name FROM person") + .map(row -> row.get("id", String.class)) + .all(); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val names = client.sql("SELECT name FROM person") + .map{ it.get("id", String.class) } + .flow() +---- + + +[[r2dbc-DatabaseClient-mapping-null]] +.What about `null`? +**** +Relational database results can contain `null` values. +The Reactive Streams specification forbids the emission of `null` values. +That requirement mandates proper `null` handling in the extractor function. +While you can obtain `null` values from a `Row`, you must not emit a `null` +value. You must wrap any `null` values in an object (for example, `Optional` +for singular values) to make sure a `null` value is never returned directly +by your extractor function. +**** + +[[r2dbc-DatabaseClient-examples-update]] +===== Updating (`INSERT`, `UPDATE`, and `DELETE`) with `DatabaseClient` + +The only difference of modifying statements is that these statements typically +do not return tabular data so you use `rowsUpdated()` to consume results. + +The following example shows an `UPDATE` statement that returns the number +of updated rows: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + Mono affectedRows = client.sql("UPDATE person SET first_name = :fn") + .bind("fn", "Joe") + .fetch().rowsUpdated(); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val affectedRows = client.sql("UPDATE person SET first_name = :fn") + .bind("fn", "Joe") + .fetch().awaitRowsUpdated() +---- + +[[r2dbc-DatabaseClient-named-parameters]] +===== Binding Values to Queries + +A typical application requires parameterized SQL statements to select or +update rows according to some input. These are typically `SELECT` statements +constrained by a `WHERE` clause or `INSERT` and `UPDATE` statements that accept +input parameters. Parameterized statements bear the risk of SQL injection if +parameters are not escaped properly. `DatabaseClient` leverages R2DBC's +`bind` API to eliminate the risk of SQL injection for query parameters. +You can provide a parameterized SQL statement with the `execute(…)` operator +and bind parameters to the actual `Statement`. Your R2DBC driver then executes +the statement by using prepared statements and parameter substitution. + +Parameter binding supports two binding strategies: + +* By Index, using zero-based parameter indexes. +* By Name, using the placeholder name. + +The following example shows parameter binding for a query: + +==== +[source,java] +---- +db.sql("INSERT INTO person (id, name, age) VALUES(:id, :name, :age)") + .bind("id", "joe") + .bind("name", "Joe") + .bind("age", 34); +---- +==== + +.R2DBC Native Bind Markers +**** +R2DBC uses database-native bind markers that depend on the actual database vendor. +As an example, Postgres uses indexed markers, such as `$1`, `$2`, `$n`. +Another example is SQL Server, which uses named bind markers prefixed with `@`. + +This is different from JDBC, which requires `?` as bind markers. +In JDBC, the actual drivers translate `?` bind markers to database-native +markers as part of their statement execution. + +Spring Framework's R2DBC support lets you use native bind markers or named bind +markers with the `:name` syntax. + +Named parameter support leverages a `BindMarkersFactory` instance to expand named +parameters to native bind markers at the time of query execution, which gives you +a certain degree of query portability across various database vendors. +**** + +The query-preprocessor unrolls named `Collection` parameters into a series of bind +markers to remove the need of dynamic query creation based on the number of arguments. +Nested object arrays are expanded to allow usage of (for example) select lists. + +Consider the following query: + +[source,sql] +---- +SELECT id, name, state FROM table WHERE (name, age) IN (('John', 35), ('Ann', 50)) +---- + +The preceding query can be parametrized and executed as follows: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + List tuples = new ArrayList<>(); + tuples.add(new Object[] {"John", 35}); + tuples.add(new Object[] {"Ann", 50}); + + client.sql("SELECT id, name, state FROM table WHERE (name, age) IN (:tuples)") + .bind("tuples", tuples); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val tuples: MutableList> = ArrayList() + tuples.add(arrayOf("John", 35)) + tuples.add(arrayOf("Ann", 50)) + + client.sql("SELECT id, name, state FROM table WHERE (name, age) IN (:tuples)") + .bind("tuples", tuples) +---- + +NOTE: Usage of select lists is vendor-dependent. + +The following example shows a simpler variant using `IN` predicates: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + client.sql("SELECT id, name, state FROM table WHERE age IN (:ages)") + .bind("ages", Arrays.asList(35, 50)); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val tuples: MutableList> = ArrayList() + tuples.add(arrayOf("John", 35)) + tuples.add(arrayOf("Ann", 50)) + + client.sql("SELECT id, name, state FROM table WHERE age IN (:ages)") + .bind("tuples", arrayOf(35, 50)) +---- + +[[r2dbc-DatbaseClient-filter]] +===== Statement Filters + +Sometimes it you need to fine-tune options on the actual `Statement` +before it gets executed. Register a `Statement` filter +(`StatementFilterFunction`) through `DatabaseClient` to intercept and +modify statements in their execution, as the following example shows: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + client.sql("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter((s, next) -> next.execute(s.returnGeneratedValues("id"))) + .bind("name", …) + .bind("state", …); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + client.sql("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter { s: Statement, next: ExecuteFunction -> next.execute(s.returnGeneratedValues("id")) } + .bind("name", …) + .bind("state", …) +---- + +`DatabaseClient` exposes also simplified `filter(…)` overload accepting `Function`: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + client.sql("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter(statement -> s.returnGeneratedValues("id")); + + client.sql("SELECT id, name, state FROM table") + .filter(statement -> s.fetchSize(25)); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + client.sql("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter { statement -> s.returnGeneratedValues("id") } + + client.sql("SELECT id, name, state FROM table") + .filter { statement -> s.fetchSize(25) } +---- + +`StatementFilterFunction` implementations allow filtering of the executed +`Statement` and filtering of `Result` objects. + +[[r2dbc-DatabaseClient-idioms]] +===== `DatabaseClient` Best Practices + +Instances of the `DatabaseClient` class are thread-safe, once configured. This is +important because it means that you can configure a single instance of a `DatabaseClient` +and then safely inject this shared reference into multiple DAOs (or repositories). +The `DatabaseClient` is stateful, in that it maintains a reference to a `ConnectionFactory`, +but this state is not conversational state. + +A common practice when using the `DatabaseClient` class is to configure a `ConnectionFactory` +in your Spring configuration file and then dependency-inject +that shared `ConnectionFactory` bean into your DAO classes. The `DatabaseClient` is created in +the setter for the `ConnectionFactory`. This leads to DAOs that resemble the following: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + public class R2dbcCorporateEventDao implements CorporateEventDao { + + private DatabaseClient databaseClient; + + public void setConnectionFactory(ConnectionFactory connectionFactory) { + this.databaseClient = DatabaseClient.create(connectionFactory); + } + + // R2DBC-backed implementations of the methods on the CorporateEventDao follow... + } +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + class R2dbcCorporateEventDao(connectionFactory: ConnectionFactory) : CorporateEventDao { + + private val databaseClient = DatabaseClient.create(connectionFactory) + + // R2DBC-backed implementations of the methods on the CorporateEventDao follow... + } +---- + +An alternative to explicit configuration is to use component-scanning and annotation +support for dependency injection. In this case, you can annotate the class with `@Component` +(which makes it a candidate for component-scanning) and annotate the `ConnectionFactory` setter +method with `@Autowired`. The following example shows how to do so: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + @Component // <1> + public class R2dbcCorporateEventDao implements CorporateEventDao { + + private DatabaseClient databaseClient; + + @Autowired // <2> + public void setConnectionFactory(ConnectionFactory connectionFactory) { + this.databaseClient = DatabaseClient.create(connectionFactory); // <3> + } + + // R2DBC-backed implementations of the methods on the CorporateEventDao follow... + } +---- +<1> Annotate the class with `@Component`. +<2> Annotate the `ConnectionFactory` setter method with `@Autowired`. +<3> Create a new `DatabaseClient` with the `ConnectionFactory`. + +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + @Component // <1> + class R2dbcCorporateEventDao(connectionFactory: ConnectionFactory) : CorporateEventDao { // <2> + + private val databaseClient = DatabaseClient(connectionFactory) // <3> + + // R2DBC-backed implementations of the methods on the CorporateEventDao follow... + } +---- +<1> Annotate the class with `@Component`. +<2> Constructor injection of the `ConnectionFactory`. +<3> Create a new `DatabaseClient` with the `ConnectionFactory`. + +Regardless of which of the above template initialization styles you choose to use (or +not), it is seldom necessary to create a new instance of a `DatabaseClient` class each +time you want to run SQL. Once configured, a `DatabaseClient` instance is thread-safe. +If your application accesses multiple +databases, you may want multiple `DatabaseClient` instances, which requires multiple +`ConnectionFactory` and, subsequently, multiple differently configured `DatabaseClient` +instances. + +[[r2dbc-auto-generated-keys]] +== Retrieving Auto-generated Keys + +`INSERT` statements may generate keys when inserting rows into a table +that defines an auto-increment or identity column. To get full control over +the column name to generate, simply register a `StatementFilterFunction` that +requests the generated key for the desired column. + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + Mono generatedId = client.sql("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter(statement -> s.returnGeneratedValues("id")) + .map(row -> row.get("id", Integer.class)) + .first(); + + // generatedId emits the generated key once the INSERT statement has finished +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val generatedId = client.sql("INSERT INTO table (name, state) VALUES(:name, :state)") + .filter { statement -> s.returnGeneratedValues("id") } + .map { row -> row.get("id", Integer.class) } + .awaitOne() + + // generatedId emits the generated key once the INSERT statement has finished +---- + + +[[r2dbc-connections]] +=== Controlling Database Connections + +This section covers: + +* <> +* <> +* <> +* <> +* <> + + +[[r2dbc-ConnectionFactory]] +==== Using `ConnectionFactory` + +Spring obtains an R2DBC connection to the database through a `ConnectionFactory`. +A `ConnectionFactory` is part of the R2DBC specification and is a common entry-point +for drivers. It lets a container or a framework hide connection pooling +and transaction management issues from the application code. As a developer, +you need not know details about how to connect to the database. That is the +responsibility of the administrator who sets up the `ConnectionFactory`. You +most likely fill both roles as you develop and test code, but you do not +necessarily have to know how the production data source is configured. + +When you use Spring's R2DBC layer, you can can configure your own with a +connection pool implementation provided by a third party. A popular +implementation is R2DBC Pool (`r2dbc-pool`). Implementations in the Spring +distribution are meant only for testing purposes and do not provide pooling. + +To configure a `ConnectionFactory`: + +. Obtain a connection with `ConnectionFactory` as you typically obtain an R2DBC `ConnectionFactory`. +. Provide an R2DBC URL +(See the documentation for your driver for the correct value). + +The following example shows how to configure a `ConnectionFactory`: + +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +.Java +---- + ConnectionFactory factory = ConnectionFactories.get("r2dbc:h2:mem:///test?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE"); +---- +[source,kotlin,indent=0,subs="verbatim,quotes",role="secondary"] +.Kotlin +---- + val factory = ConnectionFactories.get("r2dbc:h2:mem:///test?options=DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE"); +---- + + +[[r2dbc-ConnectionFactoryUtils]] +==== Using `ConnectionFactoryUtils` + + +The `ConnectionFactoryUtils` class is a convenient and powerful helper class +that provides `static` methods to obtain connections from `ConnectionFactory` +and close connections (if necessary). + +It supports subscriber ``Context``-bound connections with, for example +`R2dbcTransactionManager`. + + +[[r2dbc-SingleConnectionFactory]] +==== Using `SingleConnectionFactory` + +The `SingleConnectionFactory` class is an implementation of `DelegatingConnectionFactory` +interface that wraps a single `Connection` that is not closed after each use. + +If any client code calls `close` on the assumption of a pooled connection (as when using +persistence tools), you should set the `suppressClose` property to `true`. This setting +returns a close-suppressing proxy that wraps the physical connection. Note that you can +no longer cast this to a native `Connection` or a similar object. + +`SingleConnectionFactory` is primarily a test class and may be used for specific requirements +such as pipelining if your R2DBC driver permits for such use. +In contrast to a pooled `ConnectionFactory`, it reuses the same connection all the time, avoiding +excessive creation of physical connections. + + +[[r2dbc-TransactionAwareConnectionFactoryProxy]] +==== Using `TransactionAwareConnectionFactoryProxy` + +`TransactionAwareConnectionFactoryProxy` is a proxy for a target `ConnectionFactory`. +The proxy wraps that target `ConnectionFactory` to add awareness of Spring-managed transactions. + +NOTE: Using this class is required if you use a R2DBC client that is not integrated otherwise +with Spring's R2DBC support. In this case, you can still use this client and, at +the same time, have this client participating in Spring managed transactions. It is generally +preferable to integrate a R2DBC client with proper access to `ConnectionFactoryUtils` +for resource management. + +See the {api-spring-framework}/r2dbc/connection/TransactionAwareConnectionFactoryProxy.html[`TransactionAwareConnectionFactoryProxy`] +javadoc for more details. + + +[[r2dbc-R2dbcTransactionManager]] +==== Using `R2dbcTransactionManager` + +The `R2dbcTransactionManager` class is a `ReactiveTransactionManager` implementation for +single R2DBC datasources. It binds an R2DBC connection from the specified connection factory +to the subscriber `Context`, potentially allowing for one subscriber connection for each +connection factory. + +Application code is required to retrieve the R2DBC connection through +`ConnectionFactoryUtils.getConnection(ConnectionFactory)`, instead of R2DBC's standard +`ConnectionFactory.create()`. + +All framework classes (such as `DatabaseClient`) use this strategy implicitly. +If not used with this transaction manager, the lookup strategy behaves exactly like the common one. +Thus, it can be used in any case. + +The `R2dbcTransactionManager` class supports custom isolation levels that get applied to the connection. + + [[orm]] == Object Relational Mapping (ORM) Data Access diff --git a/src/docs/asciidoc/index.adoc b/src/docs/asciidoc/index.adoc index 800c813824cf..2494865f5425 100644 --- a/src/docs/asciidoc/index.adoc +++ b/src/docs/asciidoc/index.adoc @@ -16,7 +16,7 @@ Validation, Data Binding, Type Conversion, SpEL, AOP. <> :: Mock Objects, TestContext Framework, Spring MVC Test, WebTestClient. <> :: Transactions, DAO Support, -JDBC, O/R Mapping, XML Marshalling. +JDBC, R2DBC, O/R Mapping, XML Marshalling. <> :: Spring MVC, WebSocket, SockJS, STOMP Messaging. <> :: Spring WebFlux, WebClient,