Skip to content

Implements IsCredentialsProvider for checking if a provider matches a target provider type. #1890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .changelog/1b61ec1ce18c4cdfae74f8852ecbf877.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"id": "1b61ec1c-e18c-4cdf-ae74-f8852ecbf877",
"type": "bugfix",
"description": "The SDK client has been updated to utilize the `aws.IsCredentialsProvider` function for determining if `aws.AnonymousCredentials` has been configured for the `CredentialProvider`.",
"modules": [
"service/eventbridge",
"service/s3"
]
}
8 changes: 8 additions & 0 deletions .changelog/869890a030aa4f8e8ddd5ef80b7a01df.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "869890a0-30aa-4f8e-8ddd-5ef80b7a01df",
"type": "feature",
"description": "Adds `aws.IsCredentialsProvider` for inspecting `CredentialProvider` types when needing to determine if the underlying implementation type matches a target type. This resolves an issue where `CredentialsCache` could mask `AnonymousCredentials` providers, breaking downstream detection logic.",
"modules": [
"."
]
}
6 changes: 6 additions & 0 deletions aws/credential_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ func (p *CredentialsCache) Invalidate() {
p.creds.Store((*Credentials)(nil))
}

// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
// matches the target provider type.
func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
return IsCredentialsProvider(p.provider, target)
}

// HandleFailRefreshCredentialsCacheStrategy is an interface for
// CredentialsCache to allow CredentialsProvider how failed to refresh
// credentials is handled.
Expand Down
43 changes: 43 additions & 0 deletions aws/credential_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,46 @@ func (m mockAdjustExpiryBy) AdjustExpiresBy(creds Credentials, dur time.Duration
}
return m.creds, m.err
}

func TestCredentialsCache_IsCredentialsProvider(t *testing.T) {
tests := map[string]struct {
provider CredentialsProvider
target CredentialsProvider
want bool
}{
"nil provider and target": {
provider: nil,
target: nil,
want: true,
},
"matches value implementations": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: AnonymousCredentials{},
want: true,
},
"matches value and pointer implementations, wrapped pointer": {
provider: NewCredentialsCache(&AnonymousCredentials{}),
target: AnonymousCredentials{},
want: true,
},
"matches value and pointer implementations, pointer target": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: &AnonymousCredentials{},
want: true,
},
"does not match mismatched provider types": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: &stubCredentialsProvider{},
want: false,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := NewCredentialsCache(tt.provider).IsCredentialsProvider(tt.target); got != tt.want {
t.Errorf("IsCredentialsProvider() = %v, want %v", got, tt.want)
}
})
}
}

var _ isCredentialsProvider = (*CredentialsCache)(nil)
39 changes: 39 additions & 0 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aws
import (
"context"
"fmt"
"reflect"
"time"

"github.com/aws/aws-sdk-go-v2/internal/sdk"
Expand Down Expand Up @@ -129,3 +130,41 @@ type CredentialsProviderFunc func(context.Context) (Credentials, error)
func (fn CredentialsProviderFunc) Retrieve(ctx context.Context) (Credentials, error) {
return fn(ctx)
}

type isCredentialsProvider interface {
IsCredentialsProvider(CredentialsProvider) bool
}

// IsCredentialsProvider returns whether the target CredentialProvider is the same type as provider when comparing the
// implementation type.
//
// If provider has a method IsCredentialsProvider(CredentialsProvider) bool it will be responsible for validating
// whether target matches the credential provider type.
//
// When comparing the CredentialProvider implementations provider and target for equality, the following rules are used:
//
// If provider is of type T and target is of type V, true if type *T is the same as type *V, otherwise false
// If provider is of type *T and target is of type V, true if type *T is the same as type *V, otherwise false
// If provider is of type T and target is of type *V, true if type *T is the same as type *V, otherwise false
// If provider is of type *T and target is of type *V,true if type *T is the same as type *V, otherwise false
func IsCredentialsProvider(provider, target CredentialsProvider) bool {
if target == nil || provider == nil {
return provider == target
}

if x, ok := provider.(isCredentialsProvider); ok {
return x.IsCredentialsProvider(target)
}

targetType := reflect.TypeOf(target)
if targetType.Kind() != reflect.Ptr {
targetType = reflect.PtrTo(targetType)
}

providerType := reflect.TypeOf(provider)
if providerType.Kind() != reflect.Ptr {
providerType = reflect.PtrTo(providerType)
}

return targetType.AssignableTo(providerType)
}
83 changes: 83 additions & 0 deletions aws/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package aws

import (
"context"
"testing"
)

type anonymousNamedType AnonymousCredentials

func (f anonymousNamedType) Retrieve(ctx context.Context) (Credentials, error) {
return AnonymousCredentials(f).Retrieve(ctx)
}

func TestIsCredentialsProvider(t *testing.T) {
tests := map[string]struct {
provider CredentialsProvider
target CredentialsProvider
want bool
}{
"same implementations": {
provider: AnonymousCredentials{},
target: AnonymousCredentials{},
want: true,
},
"same implementations, pointer target": {
provider: AnonymousCredentials{},
target: &AnonymousCredentials{},
want: true,
},
"same implementations, pointer provider": {
provider: &AnonymousCredentials{},
target: AnonymousCredentials{},
want: true,
},
"same implementations, both pointers": {
provider: &AnonymousCredentials{},
target: &AnonymousCredentials{},
want: true,
},
"different implementations, nil target": {
provider: AnonymousCredentials{},
target: nil,
want: false,
},
"different implementations, nil provider": {
provider: nil,
target: AnonymousCredentials{},
want: false,
},
"different implementations": {
provider: AnonymousCredentials{},
target: &stubCredentialsProvider{},
want: false,
},
"nil provider and target": {
provider: nil,
target: nil,
want: true,
},
"implements IsCredentialsProvider, match": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: AnonymousCredentials{},
want: true,
},
"implements IsCredentialsProvider, no match": {
provider: NewCredentialsCache(AnonymousCredentials{}),
target: &stubCredentialsProvider{},
want: false,
},
"named types aliasing underlying types": {
provider: AnonymousCredentials{},
target: anonymousNamedType{},
want: false,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := IsCredentialsProvider(tt.provider, tt.target); got != tt.want {
t.Errorf("IsCredentialsProvider() = %v, want %v", got, tt.want)
}
})
}
}
7 changes: 1 addition & 6 deletions aws/signer/v4/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,8 @@ func haveCredentialProvider(p aws.CredentialsProvider) bool {
if p == nil {
return false
}
switch p.(type) {
case aws.AnonymousCredentials,
*aws.AnonymousCredentials:
return false
}

return true
return !aws.IsCredentialsProvider(p, (*aws.AnonymousCredentials)(nil))
}

type payloadHashKey struct{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public static void writeCredentialProviderResolver(GoWriter writer) {
AwsCustomGoDependency.INTERNAL_SIGV4A).build());
writer.putContext("anonType", SymbolUtils.createPointableSymbolBuilder("AnonymousCredentials",
AwsCustomGoDependency.AWS_CORE).build());
writer.putContext("isProvider", SymbolUtils.createValueSymbolBuilder("IsCredentialsProvider",
AwsCustomGoDependency.AWS_CORE).build());
writer.putContext("adapType", SymbolUtils.createPointableSymbolBuilder("SymmetricCredentialAdaptor",
AwsCustomGoDependency.INTERNAL_SIGV4A).build());
writer.write("""
Expand All @@ -54,9 +56,8 @@ public static void writeCredentialProviderResolver(GoWriter writer) {
return
}

switch o.$fieldName:L.(type) {
case $anonType:T, $anonType:P:
return
if $isProvider:T(o.$fieldName:L, ($anonType:P)(nil)) {
return
}

o.$fieldName:L = &$adapType:T{SymmetricProvider: o.$fieldName:L}
Expand Down
3 changes: 1 addition & 2 deletions service/eventbridge/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions service/s3/api_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.