Skip to content

Handle Socket IO interruptibility #1189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,15 @@ void doMaintenance() {
RuntimeException actualException = e instanceof MongoOpenConnectionInternalException
? (RuntimeException) e.getCause()
: e;
sdamProvider.optional().ifPresent(sdam -> {
if (!silentlyComplete.test(actualException)) {
sdam.handleExceptionBeforeHandshake(SdamIssue.specific(actualException, sdam.context(newConnection)));
}
});
try {
sdamProvider.optional().ifPresent(sdam -> {
if (!silentlyComplete.test(actualException)) {
sdam.handleExceptionBeforeHandshake(SdamIssue.specific(actualException, sdam.context(newConnection)));
}
});
} catch (Exception suppressed) {
actualException.addSuppressed(suppressed);
}
throw actualException;
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ public Connection getConnection(final OperationContext operationContext) {
return OperationCountTrackingConnection.decorate(this,
connectionFactory.create(connectionPool.get(operationContext), new DefaultServerProtocolExecutor(), clusterConnectionMode));
} catch (Throwable e) {
operationEnd();
if (e instanceof MongoException) {
sdam.handleExceptionBeforeHandshake(SdamIssue.specific(e, exceptionContext));
try {
operationEnd();
if (e instanceof MongoException) {
sdam.handleExceptionBeforeHandshake(SdamIssue.specific(e, exceptionContext));
}
} catch (Exception suppressed) {
e.addSuppressed(suppressed);
}
throw e;
}
Expand All @@ -117,6 +121,8 @@ public void getConnectionAsync(final OperationContext operationContext, final Si
try {
operationEnd();
sdam.handleExceptionBeforeHandshake(SdamIssue.specific(t, exceptionContext));
} catch (Exception suppressed) {
t.addSuppressed(suppressed);
} finally {
callback.onResult(null, t);
}
Expand Down Expand Up @@ -202,7 +208,11 @@ public <T> T execute(final CommandProtocol<T> protocol, final InternalConnection
protocol.sessionContext(new ClusterClockAdvancingSessionContext(sessionContext, clusterClock));
return protocol.execute(connection);
} catch (MongoException e) {
sdam.handleExceptionAfterHandshake(SdamIssue.specific(e, sdam.context(connection)));
try {
sdam.handleExceptionAfterHandshake(SdamIssue.specific(e, sdam.context(connection)));
} catch (Exception suppressed) {
e.addSuppressed(suppressed);
}
if (e instanceof MongoWriteConcernWithResponseException) {
return (T) ((MongoWriteConcernWithResponseException) e).getResponse();
} else {
Expand All @@ -223,6 +233,8 @@ public <T> void executeAsync(final CommandProtocol<T> protocol, final InternalCo
if (t != null) {
try {
sdam.handleExceptionAfterHandshake(SdamIssue.specific(t, sdam.context(connection)));
} catch (Exception suppressed) {
t.addSuppressed(suppressed);
} finally {
if (t instanceof MongoWriteConcernWithResponseException) {
callback.onResult((T) ((MongoWriteConcernWithResponseException) t).getResponse(), null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.channels.ClosedByInterruptException;
import java.util.Collections;
Expand Down Expand Up @@ -705,10 +706,12 @@ private void updateSessionContext(final SessionContext sessionContext, final Res
private MongoException translateWriteException(final Throwable e) {
if (e instanceof MongoException) {
return (MongoException) e;
}
MongoInterruptedException interruptedException = translateInterruptedExceptions(e, "Interrupted while sending message");
if (interruptedException != null) {
return interruptedException;
} else if (e instanceof IOException) {
return new MongoSocketWriteException("Exception sending message", getServerAddress(), e);
} else if (e instanceof InterruptedException) {
return new MongoInternalException("Thread interrupted exception", e);
} else {
return new MongoInternalException("Unexpected exception", e);
}
Expand All @@ -717,23 +720,52 @@ private MongoException translateWriteException(final Throwable e) {
private MongoException translateReadException(final Throwable e) {
if (e instanceof MongoException) {
return (MongoException) e;
}
MongoInterruptedException interruptedException = translateInterruptedExceptions(e, "Interrupted while receiving message");
if (interruptedException != null) {
return interruptedException;
} else if (e instanceof SocketTimeoutException) {
return new MongoSocketReadTimeoutException("Timeout while receiving message", getServerAddress(), e);
} else if (e instanceof InterruptedIOException) {
return new MongoInterruptedException("Interrupted while receiving message", (InterruptedIOException) e);
} else if (e instanceof ClosedByInterruptException) {
return new MongoInterruptedException("Interrupted while receiving message", (ClosedByInterruptException) e);
} else if (e instanceof IOException) {
return new MongoSocketReadException("Exception receiving message", getServerAddress(), e);
} else if (e instanceof RuntimeException) {
return new MongoInternalException("Unexpected runtime exception", e);
} else if (e instanceof InterruptedException) {
return new MongoInternalException("Interrupted exception", e);
} else {
return new MongoInternalException("Unexpected exception", e);
}
}

/**
* @return {@code null} iff {@code e} does not communicate an interrupt.
*/
@Nullable
private static MongoInterruptedException translateInterruptedExceptions(final Throwable e, final String message) {
if (e instanceof InterruptedException) {
// The interrupted status is cleared before throwing `InterruptedException`,
// we are not propagating `InterruptedException`, and we do not own the current thread,
// which means we must reinstate the interrupted status.
Thread.currentThread().interrupt();
return new MongoInterruptedException(message, (InterruptedException) e);
} else if (
// `InterruptedIOException` is weirdly documented, and almost seems to be a relic abandoned by the Java SE APIs:
// - `SocketTimeoutException` is `InterruptedIOException`,
// but it is not related to the Java SE interrupt mechanism. As a side note, it does not happen when writing.
// - Java SE methods, where IO may indeed be interrupted via the Java SE interrupt mechanism,
// use different exceptions, like `ClosedByInterruptException` or even `SocketException`.
(e instanceof InterruptedIOException && !(e instanceof SocketTimeoutException))
// see `java.nio.channels.InterruptibleChannel` and `java.net.Socket.getOutputStream`/`getInputStream`
|| e instanceof ClosedByInterruptException
// see `java.net.Socket.getOutputStream`/`getInputStream`
|| (e instanceof SocketException && Thread.currentThread().isInterrupted())) {
// The interrupted status is not cleared before throwing `ClosedByInterruptException`/`SocketException`,
// so we do not need to reinstate it.
// `InterruptedIOException` does not specify how it behaves with regard to the interrupted status, so we do nothing.
return new MongoInterruptedException(message, (Exception) e);
} else {
return null;
}
}

private ResponseBuffers receiveResponseBuffers(final int additionalTimeout) throws IOException {
ByteBuf messageHeaderBuffer = stream.read(MESSAGE_HEADER_LENGTH, additionalTimeout);
MessageHeader messageHeader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ public ByteBuf read(final int numBytes, final int additionalTimeout) throws IOEx
try {
return read(numBytes);
} finally {
socket.setSoTimeout(curTimeout);
if (!socket.isClosed()) {
// `socket` may be closed if the current thread is virtual, and it is interrupted while reading
socket.setSoTimeout(curTimeout);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.mongodb.internal.connection

import com.mongodb.MongoCommandException
import com.mongodb.MongoInternalException
import com.mongodb.MongoInterruptedException
import com.mongodb.MongoNamespace
import com.mongodb.MongoSocketClosedException
import com.mongodb.MongoSocketException
Expand Down Expand Up @@ -54,6 +55,7 @@ import org.bson.codecs.configuration.CodecConfigurationException
import spock.lang.Specification

import java.nio.ByteBuffer
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
Expand Down Expand Up @@ -312,6 +314,149 @@ class InternalStreamConnectionSpecification extends Specification {
connection.isClosed()
}

def 'should throw MongoInterruptedException and leave the interrupt status set when Stream.write throws InterruptedIOException'() {
given:
stream.write(_) >> { throw new InterruptedIOException() }
def connection = getOpenedConnection()
Thread.currentThread().interrupt()

when:
connection.sendMessage([new ByteBufNIO(ByteBuffer.allocate(1))], 1)

then:
Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoInterruptedException and leave the interrupt status unset when Stream.write throws InterruptedIOException'() {
given:
stream.write(_) >> { throw new InterruptedIOException() }
def connection = getOpenedConnection()

when:
connection.sendMessage([new ByteBufNIO(ByteBuffer.allocate(1))], 1)

then:
!Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoInterruptedException and leave the interrupt status set when Stream.write throws ClosedByInterruptException'() {
given:
stream.write(_) >> { throw new ClosedByInterruptException() }
def connection = getOpenedConnection()
Thread.currentThread().interrupt()

when:
connection.sendMessage([new ByteBufNIO(ByteBuffer.allocate(1))], 1)

then:
Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoInterruptedException when Stream.write throws SocketException and the thread is interrupted'() {
given:
stream.write(_) >> { throw new SocketException() }
def connection = getOpenedConnection()
Thread.currentThread().interrupt()

when:
connection.sendMessage([new ByteBufNIO(ByteBuffer.allocate(1))], 1)

then:
Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoSocketWriteException when Stream.write throws SocketException and the thread is not interrupted'() {
given:
stream.write(_) >> { throw new SocketException() }
def connection = getOpenedConnection()

when:
connection.sendMessage([new ByteBufNIO(ByteBuffer.allocate(1))], 1)

then:
thrown(MongoSocketWriteException)
connection.isClosed()
}

def 'should throw MongoInterruptedException and leave the interrupt status set when Stream.read throws InterruptedIOException'() {
given:
stream.read(_, _) >> { throw new InterruptedIOException() }
def connection = getOpenedConnection()
Thread.currentThread().interrupt()

when:
connection.receiveMessage(1)

then:
Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoInterruptedException and leave the interrupt status unset when Stream.read throws InterruptedIOException'() {
given:
stream.read(_, _) >> { throw new InterruptedIOException() }
def connection = getOpenedConnection()

when:
connection.receiveMessage(1)

then:
!Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoInterruptedException and leave the interrupt status set when Stream.read throws ClosedByInterruptException'() {
given:
stream.read(_, _) >> { throw new ClosedByInterruptException() }
def connection = getOpenedConnection()
Thread.currentThread().interrupt()

when:
connection.receiveMessage(1)

then:
Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoInterruptedException when Stream.read throws SocketException and the thread is interrupted'() {
given:
stream.read(_, _) >> { throw new SocketException() }
def connection = getOpenedConnection()
Thread.currentThread().interrupt()

when:
connection.receiveMessage(1)

then:
Thread.interrupted()
thrown(MongoInterruptedException)
connection.isClosed()
}

def 'should throw MongoSocketReadException when Stream.read throws SocketException and the thread is not interrupted'() {
given:
stream.read(_, _) >> { throw new SocketException() }
def connection = getOpenedConnection()

when:
connection.receiveMessage(1)

then:
thrown(MongoSocketReadException)
connection.isClosed()
}

def 'should close the stream when reading the message header throws an exception asynchronously'() {
given:
Expand Down