diff --git a/http2/transport.go b/http2/transport.go index 4ec326699..a5b4975d9 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -130,6 +130,51 @@ type Transport struct { connPoolOrDef ClientConnPool // non-nil version of ConnPool } +// transportOptions configure the ConfigureTransport. transportOptions are set by the TransportOption +// values passed to Transport. +type transportOptions struct { + // readIdleTimeout corresponds to Transport.ReadIdleTimeout + readIdleTimeout time.Duration + + // pingTimeout corresponds to Transport.PingTimeout + pingTimeout time.Duration +} + +// TransportOption configures how we set up the extra parameters(such as http2 health check) of Transport besides http.Transport when call ConfigureTransport. +type TransportOption interface { + apply(*transportOptions) +} + +// funcTransportOption wraps a function that modifies transportOptions into an +// implementation of the TransportOption interface. +type funcTransportOption struct { + f func(*transportOptions) +} + +func (fto *funcTransportOption) apply(to *transportOptions) { + fto.f(to) +} + +func newFuncTransportOption(f func(*transportOptions)) *funcTransportOption { + return &funcTransportOption{ + f: f, + } +} + +// WithReadIdleTimeout returns a TransportOption which sets the Transport.ReadIdleTimeout +func WithReadIdleTimeout(readIdleTimeout time.Duration) TransportOption { + return newFuncTransportOption(func(o *transportOptions) { + o.readIdleTimeout = readIdleTimeout + }) +} + +// WithPingTimeout returns a TransportOption which sets the Transport.PingTimeout +func WithPingTimeout(pingTimeout time.Duration) TransportOption { + return newFuncTransportOption(func(o *transportOptions) { + o.pingTimeout = pingTimeout + }) +} + func (t *Transport) maxHeaderListSize() uint32 { if t.MaxHeaderListSize == 0 { return 10 << 20 @@ -154,17 +199,25 @@ func (t *Transport) pingTimeout() time.Duration { // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. // It returns an error if t1 has already been HTTP/2-enabled. -func ConfigureTransport(t1 *http.Transport) error { - _, err := configureTransport(t1) +func ConfigureTransport(t1 *http.Transport, opts ...TransportOption) error { + _, err := configureTransport(t1, opts...) return err } -func configureTransport(t1 *http.Transport) (*Transport, error) { +func configureTransport(t1 *http.Transport, opts ...TransportOption) (*Transport, error) { + t2Opts := transportOptions{} + for _, o := range opts { + o.apply(&t2Opts) + } connPool := new(clientConnPool) t2 := &Transport{ ConnPool: noDialClientConnPool{connPool}, t1: t1, } + if t2Opts.readIdleTimeout != 0 { + t2.ReadIdleTimeout = t2Opts.readIdleTimeout + t2.PingTimeout = t2Opts.pingTimeout + } connPool.t = t2 if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil { return nil, err diff --git a/http2/transport_test.go b/http2/transport_test.go index 2fdb117ac..37181337f 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -32,6 +32,7 @@ import ( "sync/atomic" "testing" "time" + "unsafe" "golang.org/x/net/http2/hpack" ) @@ -594,9 +595,29 @@ func TestTransportDialTLS(t *testing.T) { } } +func TestConfigureTransportWithOptions(t *testing.T) { + t1 := &http.Transport{} + err := ConfigureTransport(t1, WithReadIdleTimeout(10*time.Second), WithPingTimeout(2*time.Second)) + if err != nil { + t.Fatal(err) + } + rf := reflect.ValueOf(t1).Elem().FieldByName("altProto") + rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + v := rf.Interface().(atomic.Value) + altProto := v.Load().(map[string]http.RoundTripper) + rt := (altProto["https"]).(noDialH2RoundTripper) + t2 := rt.Transport + if t2.ReadIdleTimeout != 10*time.Second { + t.Errorf("expected ReadIdleTimeout to be 10s, got %v", t2.ReadIdleTimeout) + } + if t2.PingTimeout != 2*time.Second { + t.Errorf("expected PingTimeout to be 2s, got %v", t2.PingTimeout) + } +} + func TestConfigureTransport(t *testing.T) { t1 := &http.Transport{} - err := ConfigureTransport(t1) + err := ConfigureTransport(t1, WithReadIdleTimeout(3*time.Second), WithPingTimeout(1*time.Second)) if err != nil { t.Fatal(err) }