Skip to content

Commit add337e

Browse files
storyiconholimanfjl
authored
rpc: support injecting HTTP headers through context (#26023)
This adds a way to specify HTTP headers per request. Co-authored-by: Martin Holst Swende <[email protected]> Co-authored-by: Felix Lange <[email protected]>
1 parent b4ea2bf commit add337e

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

rpc/context_headers.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright 2022 The go-ethereum Authors
2+
// This file is part of the go-ethereum library.
3+
//
4+
// The go-ethereum library is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Lesser General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// The go-ethereum library is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Lesser General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Lesser General Public License
15+
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16+
17+
package rpc
18+
19+
import (
20+
"context"
21+
"net/http"
22+
)
23+
24+
type mdHeaderKey struct{}
25+
26+
// NewContextWithHeaders wraps the given context, adding HTTP headers. These headers will
27+
// be applied by Client when making a request using the returned context.
28+
func NewContextWithHeaders(ctx context.Context, h http.Header) context.Context {
29+
if len(h) == 0 {
30+
// This check ensures the header map set in context will never be nil.
31+
return ctx
32+
}
33+
34+
var ctxh http.Header
35+
prev, ok := ctx.Value(mdHeaderKey{}).(http.Header)
36+
if ok {
37+
ctxh = setHeaders(prev.Clone(), h)
38+
} else {
39+
ctxh = h.Clone()
40+
}
41+
return context.WithValue(ctx, mdHeaderKey{}, ctxh)
42+
}
43+
44+
// headersFromContext is used to extract http.Header from context.
45+
func headersFromContext(ctx context.Context) http.Header {
46+
source, _ := ctx.Value(mdHeaderKey{}).(http.Header)
47+
return source
48+
}
49+
50+
// setHeaders sets all headers from src in dst.
51+
func setHeaders(dst http.Header, src http.Header) http.Header {
52+
for key, values := range src {
53+
dst[http.CanonicalHeaderKey(key)] = values
54+
}
55+
return dst
56+
}

rpc/http.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
217217
hc.mu.Lock()
218218
req.Header = hc.headers.Clone()
219219
hc.mu.Unlock()
220+
setHeaders(req.Header, headersFromContext(ctx))
221+
220222
if hc.auth != nil {
221223
if err := hc.auth(req.Header); err != nil {
222224
return nil, err

rpc/http_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package rpc
1818

1919
import (
20+
"context"
21+
"fmt"
2022
"net/http"
2123
"net/http/httptest"
2224
"strings"
@@ -198,3 +200,43 @@ func TestHTTPPeerInfo(t *testing.T) {
198200
t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent)
199201
}
200202
}
203+
204+
func TestNewContextWithHeaders(t *testing.T) {
205+
expectedHeaders := 0
206+
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
207+
for i := 0; i < expectedHeaders; i++ {
208+
key, want := fmt.Sprintf("key-%d", i), fmt.Sprintf("val-%d", i)
209+
if have := request.Header.Get(key); have != want {
210+
t.Errorf("wrong request headers for %s, want: %s, have: %s", key, want, have)
211+
}
212+
}
213+
writer.WriteHeader(http.StatusOK)
214+
_, _ = writer.Write([]byte(`{}`))
215+
}))
216+
defer server.Close()
217+
218+
client, err := Dial(server.URL)
219+
if err != nil {
220+
t.Fatalf("failed to dial: %s", err)
221+
}
222+
defer client.Close()
223+
224+
newHdr := func(k, v string) http.Header {
225+
header := http.Header{}
226+
header.Set(k, v)
227+
return header
228+
}
229+
ctx1 := NewContextWithHeaders(context.Background(), newHdr("key-0", "val-0"))
230+
ctx2 := NewContextWithHeaders(ctx1, newHdr("key-1", "val-1"))
231+
ctx3 := NewContextWithHeaders(ctx2, newHdr("key-2", "val-2"))
232+
233+
expectedHeaders = 3
234+
if err := client.CallContext(ctx3, nil, "test"); err != ErrNoResult {
235+
t.Error("call failed", err)
236+
}
237+
238+
expectedHeaders = 2
239+
if err := client.CallContext(ctx2, nil, "test"); err != ErrNoResult {
240+
t.Error("call failed:", err)
241+
}
242+
}

0 commit comments

Comments
 (0)