Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
import java.util.function.Function;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;
import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider;
import software.amazon.awssdk.services.s3.model.S3Request;

@SdkInternalApi
Expand Down Expand Up @@ -67,28 +66,12 @@ private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3
S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider)
requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get());

// TODO : separate PR to provide supplier for Async client
return requestOverrideConfig.toBuilder()
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket))
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, null))
.build();
}

//TODO: add cross region logic
static final class BucketEndpointProvider implements S3EndpointProvider {
private final S3EndpointProvider delegate;
private final String bucket;

private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) {
this.delegate = delegate;
this.bucket = bucket;
}

public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) {
return new BucketEndpointProvider(delegate, bucket);
}

@Override
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
return delegate.resolveEndpoint(endpointParams);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,80 +15,110 @@

package software.amazon.awssdk.services.s3.internal.crossregion;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.DelegatingS3Client;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;
import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider;
import software.amazon.awssdk.services.s3.model.HeadBucketRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.s3.model.S3Request;

/**
* Decorator S3 Sync client that will fetch the region name whenever there is Redirect 301 error due to cross region bucket
* access.
*/
@SdkInternalApi
public final class S3CrossRegionSyncClient extends DelegatingS3Client {

private static final String AMZ_BUCKET_REGION_HEADER = "x-amz-bucket-region";
public static final int REDIRECT_STATUS_CODE = 301;
private final Map<String, Region> bucketToRegionCache = new ConcurrentHashMap<>();

public S3CrossRegionSyncClient(S3Client s3Client) {
super(s3Client);
}

private static <T extends S3Request> Optional<String> bucketNameFromRequest(T request) {
return request.getValueForField("Bucket", String.class);
}

@Override
protected <T extends S3Request, ReturnT> ReturnT invokeOperation(T request, Function<T, ReturnT> operation) {

Optional<String> bucket = request.getValueForField("Bucket", String.class);

if (bucket.isPresent()) {
try {
return operation.apply(requestWithDecoratedEndpointProvider(request, bucket.get()));
} catch (Exception e) {
handleOperationFailure(e, bucket.get());
Optional<String> bucketRequest = bucketNameFromRequest(request);
if (!bucketRequest.isPresent()) {
return operation.apply(request);
}
String bucketName = bucketRequest.get();
try {
if (bucketToRegionCache.containsKey(bucketName)) {
return operation.apply(requestWithDecoratedEndpointProvider(request, regionSupplier(bucketName)));
}
return operation.apply(request);
} catch (S3Exception exception) {
if (exception.statusCode() == REDIRECT_STATUS_CODE) {
updateCacheFromRedirectException(exception, bucketName);
return operation.apply(requestWithDecoratedEndpointProvider(request, regionSupplier(bucketName)));
}
throw exception;
}
}

private String updateCacheFromRedirectException(S3Exception exception, String bucketName) {
Optional<String> regionStr = getBucketRegionFromException(exception);
// If redirected, clear previous values due to region change. bucketToRegionCache.remove(bucketName);
regionStr.ifPresent(region -> bucketToRegionCache.put(bucketName, Region.of(region)));
return regionStr.orElse(null);
}

return operation.apply(request);
private Supplier<Region> regionSupplier(String bucket) {
return () -> bucketToRegionCache.computeIfAbsent(bucket, this::fetchBucketRegion);
}

private void handleOperationFailure(Throwable t, String bucket) {
//TODO: handle failure case
private Region fetchBucketRegion(String bucketName) {
try {
((S3Client) delegate()).headBucket(HeadBucketRequest.builder().bucket(bucketName).build());
} catch (S3Exception exception) {
if (exception.statusCode() == REDIRECT_STATUS_CODE) {
return Region.of(getBucketRegionFromException(exception).orElseThrow(() -> exception));
}
throw exception;
}
return null;
}

@SuppressWarnings("unchecked")
private <T extends S3Request> T requestWithDecoratedEndpointProvider(T request, String bucket) {
private <T extends S3Request> T requestWithDecoratedEndpointProvider(T request, Supplier<Region> regionSupplier) {
return (T) request.toBuilder()
.overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, bucket))
.overrideConfiguration(getOrCreateConfigWithEndpointProvider(request, regionSupplier))
.build();
}

//TODO: optimize shared sync/async code
private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request, String bucket) {
private AwsRequestOverrideConfiguration getOrCreateConfigWithEndpointProvider(S3Request request,
Supplier<Region> regionSupplier) {
AwsRequestOverrideConfiguration requestOverrideConfig =
request.overrideConfiguration().orElseGet(() -> AwsRequestOverrideConfiguration.builder().build());

S3EndpointProvider delegateEndpointProvider = (S3EndpointProvider)
requestOverrideConfig.endpointProvider().orElseGet(() -> serviceClientConfiguration().endpointProvider().get());

return requestOverrideConfig.toBuilder()
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, bucket))
.endpointProvider(BucketEndpointProvider.create(delegateEndpointProvider, regionSupplier))
.build();
}

static final class BucketEndpointProvider implements S3EndpointProvider {
private final S3EndpointProvider delegate;
private final String bucket;

private BucketEndpointProvider(S3EndpointProvider delegate, String bucket) {
this.delegate = delegate;
this.bucket = bucket;
}

public static BucketEndpointProvider create(S3EndpointProvider delegate, String bucket) {
return new BucketEndpointProvider(delegate, bucket);
}

@Override
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
return delegate.resolveEndpoint(endpointParams);
}
private Optional<String> getBucketRegionFromException(S3Exception exception) {
return exception.awsErrorDetails()
.sdkHttpResponse()
.firstMatchingHeader(AMZ_BUCKET_REGION_HEADER);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider;

import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams;
import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider;

/**
* Decorator S3EndpointProvider which updates the region with the one that is supplied during its instantiation.
*/
@SdkInternalApi
public class BucketEndpointProvider implements S3EndpointProvider {
private final S3EndpointProvider delegateEndPointProvider;
private final Supplier<Region> regionSupplier;

private BucketEndpointProvider(S3EndpointProvider delegateEndPointProvider, Supplier<Region> regionSupplier) {
this.delegateEndPointProvider = delegateEndPointProvider;
this.regionSupplier = regionSupplier;
}

public static BucketEndpointProvider create(S3EndpointProvider delegateEndPointProvider, Supplier<Region> regionSupplier) {
return new BucketEndpointProvider(delegateEndPointProvider, regionSupplier);
}

@Override
public CompletableFuture<Endpoint> resolveEndpoint(S3EndpointParams endpointParams) {
Region crossRegion = regionSupplier.get();
return delegateEndPointProvider.resolveEndpoint(
crossRegion != null ? endpointParams.toBuilder().region(crossRegion).build() : endpointParams);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.endpoints.internal.DefaultS3EndpointProvider;
import software.amazon.awssdk.services.s3.internal.crossregion.endpointprovider.BucketEndpointProvider;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
Expand Down Expand Up @@ -71,7 +72,7 @@ public void before() {
public void standardOp_crossRegionClient_noOverrideConfig_SuccessfullyIntercepts() {
S3AsyncClient crossRegionClient = new S3CrossRegionAsyncClient(s3Client);
crossRegionClient.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes());
assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class);
assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class);
}

@Test
Expand All @@ -83,7 +84,7 @@ public void standardOp_crossRegionClient_existingOverrideConfig_SuccessfullyInte
.overrideConfiguration(o -> o.putHeader("someheader", "somevalue"))
.build();
crossRegionClient.getObject(request, AsyncResponseTransformer.toBytes());
assertThat(captureInterceptor.endpointProvider).isInstanceOf(S3CrossRegionAsyncClient.BucketEndpointProvider.class);
assertThat(captureInterceptor.endpointProvider).isInstanceOf(BucketEndpointProvider.class);
assertThat(mockAsyncHttpClient.getLastRequest().headers().get("someheader")).isNotNull();
}

Expand Down
Loading