Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import com.google.common.io.Files;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.CompositeChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
Expand All @@ -69,6 +70,7 @@
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -146,7 +148,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;
@Nullable private final SecureSessionAgent s2aConfigProvider;
@Nullable private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
@VisibleForTesting final Map<String, String> headersWithDuplicatesRemoved = new HashMap<>();

@Nullable
Expand Down Expand Up @@ -175,7 +177,10 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.headerProvider = builder.headerProvider;
this.useS2A = builder.useS2A;
this.endpoint = builder.endpoint;
this.allowedHardBoundTokenTypes = builder.allowedHardBoundTokenTypes;
this.allowedHardBoundTokenTypes =
builder.allowedHardBoundTokenTypes == null
? new ArrayList<>()
: builder.allowedHardBoundTokenTypes;
this.mtlsProvider = builder.mtlsProvider;
this.s2aConfigProvider = builder.s2aConfigProvider;
this.envProvider = builder.envProvider;
Expand Down Expand Up @@ -592,6 +597,35 @@ ChannelCredentials createS2ASecuredChannelCredentials() {
}
}

boolean isMtlsS2AHardBoundTokensEnabled() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a thought (nothing that needs to be changed in this PR): With how many helper methods we have for S2A and hard bound tokens, I wonder if we can split these methods into a helper class in Gax-Grpc (something like S2AMtlsContext or something)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this would be help to reduce the complexity in the InstantiatingGrpcChannelProvider file. I'm happy to do the cleanup of that in a followup CL.

if (!useS2A
// If S2A cannot be used, {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens should not
// be used
|| allowedHardBoundTokenTypes.isEmpty()
|| credentials == null
|| !(credentials instanceof ComputeEngineCredentials)) {
return false;
}
return allowedHardBoundTokenTypes.stream()
.anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A));
}

CallCredentials createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport,
ComputeEngineCredentials.BindingEnforcement bindingEnforcement) {
ComputeEngineCredentials.Builder credsBuilder =
((ComputeEngineCredentials) credentials).toBuilder();
// We only set scopes and HTTP transport factory from the original credentials because
// only those are used in gRPC CallCredentials to fetch request metadata.
return MoreCallCredentials.from(
ComputeEngineCredentials.newBuilder()
.setScopes(credsBuilder.getScopes())
.setHttpTransportFactory(credsBuilder.getHttpTransportFactory())
.setGoogleAuthTransport(googleAuthTransport)
.setBindingEnforcement(bindingEnforcement)
.build());
}

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headersWithDuplicatesRemoved);
Expand Down Expand Up @@ -648,6 +682,15 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
if (isMtlsS2AHardBoundTokensEnabled()) {
// Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
CallCredentials callCreds =
createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport.MTLS,
ComputeEngineCredentials.BindingEnforcement.ON);
channelCredentials = CompositeChannelCredentials.create(channelCredentials, callCreds);
}
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
Expand Down Expand Up @@ -818,7 +861,7 @@ public static final class Builder {
@Nullable private Boolean attemptDirectPathXds;
@Nullable private Boolean allowNonDefaultServiceAccount;
@Nullable private ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;

private Builder() {
processorCount = Runtime.getRuntime().availableProcessors();
Expand Down Expand Up @@ -846,6 +889,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.attemptDirectPath = provider.attemptDirectPath;
this.attemptDirectPathXds = provider.attemptDirectPathXds;
this.allowNonDefaultServiceAccount = provider.allowNonDefaultServiceAccount;
this.allowedHardBoundTokenTypes = provider.allowedHardBoundTokenTypes;
this.directPathServiceConfig = provider.directPathServiceConfig;
this.mtlsProvider = provider.mtlsProvider;
this.s2aConfigProvider = provider.s2aConfigProvider;
Expand Down Expand Up @@ -914,7 +958,10 @@ Builder setUseS2A(boolean useS2A) {
*/
@InternalApi
public Builder setAllowHardBoundTokenTypes(List<HardBoundTokenTypes> allowedValues) {
this.allowedHardBoundTokenTypes = allowedValues;
this.allowedHardBoundTokenTypes =
Preconditions.checkNotNull(
allowedValues, "Illegal Argument, allowedValues cannot be null");
;
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,82 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
}

@Test
void isMtlsS2AHardBoundTokensEnabled_useS2AFalse() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(false)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_hardBoundTokenTypesEmpty() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(new ArrayList<>())
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_nullCreds() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(null)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_notComputeEngineCreds() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(CloudShellCredentials.create(3000))
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ANotInList() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS))
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_returnsTrue() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isTrue();
}

private static class FakeLogHandler extends Handler {

List<LogRecord> records = new ArrayList<>();
Expand Down
Loading