Skip to content

Commit bd829bb

Browse files
committed
Adds support for EC2Metadata client secure token
1 parent dadd7ec commit bd829bb

File tree

5 files changed

+1236
-190
lines changed

5 files changed

+1236
-190
lines changed

aws/ec2metadata/api_client.go

Lines changed: 92 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"net"
1515
"net/http"
1616
"os"
17+
"strconv"
1718
"strings"
1819
"time"
1920

@@ -22,7 +23,25 @@ import (
2223
"github.com/aws/aws-sdk-go-v2/aws/defaults"
2324
)
2425

25-
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
26+
const (
27+
// ServiceName is the name of the service.
28+
ServiceName = "ec2metadata"
29+
disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
30+
31+
// Headers for Token and TTL
32+
ttlHeader = "x-aws-ec2-metadata-token-ttl-seconds"
33+
tokenHeader = "x-aws-ec2-metadata-token"
34+
35+
// Named Handler constants
36+
fetchTokenHandlerName = "FetchTokenHandler"
37+
unmarshalMetadataHandlerName = "unmarshalMetadataHandler"
38+
unmarshalTokenHandlerName = "unmarshalTokenHandler"
39+
enableTokenProviderHandlerName = "enableTokenProviderHandler"
40+
41+
// TTL constants
42+
defaultTTL = 21600 * time.Second
43+
ttlExpirationWindow = 30 * time.Second
44+
)
2645

2746
// A Client is an EC2 Instance Metadata service Client.
2847
type Client struct {
@@ -61,7 +80,20 @@ func New(config aws.Config) *Client {
6180
),
6281
}
6382

64-
svc.Handlers.Unmarshal.PushBack(unmarshalHandler)
83+
// token provider instance
84+
tp := newTokenProvider(svc, defaultTTL)
85+
// NamedHandler for fetching token
86+
svc.Handlers.Sign.PushBackNamed(aws.NamedHandler{
87+
Name: fetchTokenHandlerName,
88+
Fn: tp.fetchTokenHandler,
89+
})
90+
// NamedHandler for enabling token provider
91+
svc.Handlers.Complete.PushBackNamed(aws.NamedHandler{
92+
Name: enableTokenProviderHandlerName,
93+
Fn: tp.enableTokenProviderHandler,
94+
})
95+
96+
svc.Handlers.Unmarshal.PushBackNamed(unmarshalHandler)
6597
svc.Handlers.UnmarshalError.PushBack(unmarshalError)
6698
svc.Handlers.Validate.Clear()
6799
svc.Handlers.Validate.PushBack(validateEndpointHandler)
@@ -91,30 +123,74 @@ type metadataOutput struct {
91123
Content string
92124
}
93125

94-
func unmarshalHandler(r *aws.Request) {
95-
defer r.HTTPResponse.Body.Close()
96-
b := &bytes.Buffer{}
97-
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
98-
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err)
99-
return
100-
}
126+
type tokenOutput struct {
127+
Token string
128+
TTL time.Duration
129+
}
101130

102-
if data, ok := r.Data.(*metadataOutput); ok {
103-
data.Content = b.String()
104-
}
131+
// unmarshal token handler is used to parse the response of a getToken operation
132+
var unmarshalTokenHandler = aws.NamedHandler{
133+
Name: unmarshalTokenHandlerName,
134+
Fn: func(r *aws.Request) {
135+
defer r.HTTPResponse.Body.Close()
136+
var b bytes.Buffer
137+
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
138+
r.Error = awserr.NewRequestFailure(awserr.New(aws.ErrCodeSerialization,
139+
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
140+
return
141+
}
142+
143+
v := r.HTTPResponse.Header.Get(ttlHeader)
144+
data, ok := r.Data.(*tokenOutput)
145+
if !ok {
146+
return
147+
}
148+
149+
data.Token = b.String()
150+
// TTL is in seconds
151+
i, err := strconv.ParseInt(v, 10, 64)
152+
if err != nil {
153+
r.Error = awserr.NewRequestFailure(awserr.New(aws.ParamFormatErrCode,
154+
"unable to parse EC2 token TTL response", err), r.HTTPResponse.StatusCode, r.RequestID)
155+
return
156+
}
157+
t := time.Duration(i) * time.Second
158+
data.TTL = t
159+
},
160+
}
161+
162+
var unmarshalHandler = aws.NamedHandler{
163+
Name: unmarshalMetadataHandlerName,
164+
Fn: func(r *aws.Request) {
165+
defer r.HTTPResponse.Body.Close()
166+
var b bytes.Buffer
167+
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
168+
r.Error = awserr.NewRequestFailure(awserr.New(aws.ErrCodeSerialization,
169+
"unable to unmarshal EC2 metadata response", err), r.HTTPResponse.StatusCode, r.RequestID)
170+
return
171+
}
172+
173+
if data, ok := r.Data.(*metadataOutput); ok {
174+
data.Content = b.String()
175+
}
176+
},
105177
}
106178

107179
func unmarshalError(r *aws.Request) {
108180
defer r.HTTPResponse.Body.Close()
109-
b := &bytes.Buffer{}
110-
if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil {
111-
r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err)
181+
var b bytes.Buffer
182+
183+
if _, err := io.Copy(&b, r.HTTPResponse.Body); err != nil {
184+
r.Error = awserr.NewRequestFailure(
185+
awserr.New(aws.ErrCodeSerialization, "unable to unmarshal EC2 metadata error response", err),
186+
r.HTTPResponse.StatusCode, r.RequestID)
112187
return
113188
}
114189

115190
// Response body format is not consistent between metadata endpoints.
116191
// Grab the error message as a string and include that as the source error
117-
r.Error = awserr.New("EC2MetadataError", "failed to make Client request", errors.New(b.String()))
192+
r.Error = awserr.NewRequestFailure(awserr.New("EC2MetadataError", "failed to make EC2Metadata request", errors.New(b.String())),
193+
r.HTTPResponse.StatusCode, r.RequestID)
118194
}
119195

120196
func validateEndpointHandler(r *aws.Request) {

aws/ec2metadata/api_client_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
package ec2metadata_test
22

33
import (
4-
"os"
5-
"strings"
6-
"testing"
7-
84
"github.com/aws/aws-sdk-go-v2/aws"
95
"github.com/aws/aws-sdk-go-v2/aws/awserr"
106
"github.com/aws/aws-sdk-go-v2/aws/ec2metadata"
117
"github.com/aws/aws-sdk-go-v2/internal/awstesting"
128
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
9+
"os"
10+
"strings"
11+
"testing"
1312
)
1413

1514
func TestClientDisableIMDS(t *testing.T) {
@@ -23,7 +22,7 @@ func TestClientDisableIMDS(t *testing.T) {
2322
cfg.Logger = t
2423

2524
svc := ec2metadata.New(cfg)
26-
resp, err := svc.Region()
25+
resp, err := svc.GetUserData()
2726
if err == nil {
2827
t.Fatalf("expect error, got none")
2928
}

aws/ec2metadata/api_ops.go

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,49 @@ import (
55
"fmt"
66
"net/http"
77
"path"
8+
"strconv"
89
"strings"
910
"time"
1011

1112
"github.com/aws/aws-sdk-go-v2/aws"
1213
"github.com/aws/aws-sdk-go-v2/aws/awserr"
1314
)
1415

16+
// getToken uses the duration to return a token for EC2 metadata service,
17+
// or an error if the request failed.
18+
func (c *Client) getToken(duration time.Duration) (tokenOutput, error) {
19+
op := &aws.Operation{
20+
Name: "GetToken",
21+
HTTPMethod: "PUT",
22+
HTTPPath: "/api/token",
23+
}
24+
25+
var output tokenOutput
26+
req := c.NewRequest(op, nil, &output)
27+
28+
// remove the fetch token handler from the request handlers to avoid infinite recursion
29+
req.Handlers.Sign.RemoveByName(fetchTokenHandlerName)
30+
31+
// Swap the unmarshalMetadataHandler with unmarshalTokenHandler on this request.
32+
req.Handlers.Unmarshal.Swap(unmarshalMetadataHandlerName, unmarshalTokenHandler)
33+
34+
ttl := strconv.FormatInt(int64(duration/time.Second), 10)
35+
req.HTTPRequest.Header.Set(ttlHeader, ttl)
36+
37+
err := req.Send()
38+
39+
// Errors with bad request status should be returned.
40+
if err != nil {
41+
err = awserr.NewRequestFailure(
42+
awserr.New(req.HTTPResponse.Status, http.StatusText(req.HTTPResponse.StatusCode), err),
43+
req.HTTPResponse.StatusCode, req.RequestID)
44+
}
45+
46+
return output, err
47+
}
48+
1549
// GetMetadata uses the path provided to request information from the EC2
16-
// instance metdata service. The content will be returned as a string, or
50+
// instance metadata service. The content will be returned as a string, or
1751
// error if the request failed.
1852
func (c *Client) GetMetadata(p string) (string, error) {
1953
op := &aws.Operation{
@@ -40,12 +74,6 @@ func (c *Client) GetUserData() (string, error) {
4074

4175
output := &metadataOutput{}
4276
req := c.NewRequest(op, nil, output)
43-
req.Handlers.UnmarshalError.PushBack(func(r *aws.Request) {
44-
if r.HTTPResponse.StatusCode == http.StatusNotFound {
45-
r.Error = awserr.New("NotFoundError", "user-data not found", r.Error)
46-
}
47-
})
48-
4977
return output.Content, req.Send()
5078
}
5179

@@ -113,13 +141,17 @@ func (c *Client) IAMInfo() (EC2IAMInfo, error) {
113141

114142
// Region returns the region the instance is running in.
115143
func (c *Client) Region() (string, error) {
116-
resp, err := c.GetMetadata("placement/availability-zone")
144+
ec2InstanceIdentityDocument, err := c.GetInstanceIdentityDocument()
117145
if err != nil {
118146
return "", err
119147
}
120-
121-
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
122-
return resp[:len(resp)-1], nil
148+
// extract region from the ec2InstanceIdentityDocument
149+
region := ec2InstanceIdentityDocument.Region
150+
if len(region) == 0 {
151+
return "", awserr.New("EC2MetadataError", "invalid region received for ec2metadata instance", nil)
152+
}
153+
// returns region
154+
return region, nil
123155
}
124156

125157
// Available returns if the application has access to the EC2 Instance Metadata

0 commit comments

Comments
 (0)