Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package config

import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
Expand All @@ -38,6 +40,12 @@ var DefaultHTTPClientConfig = HTTPClientConfig{
FollowRedirects: true,
}

// defaultHTTPClientOptions holds the default HTTP client options.
var defaultHTTPClientOptions = httpClientOptions{
keepAlivesEnabled: true,
http2Enabled: true,
}

type closeIdler interface {
CloseIdleConnections()
}
Expand Down Expand Up @@ -194,15 +202,50 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
return unmarshal((*plain)(a))
}

// DialContextFunc defines the signature of the DialContext() function implemented
// by net.Dialer.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)

type httpClientOptions struct {
dialContextFunc DialContextFunc
keepAlivesEnabled bool
http2Enabled bool
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
type HTTPClientOption func(options *httpClientOptions)

// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.dialContextFunc = fn
}
}

// WithKeepAlivesDisabled allows to disable HTTP keepalive.
func WithKeepAlivesDisabled() HTTPClientOption {
return func(opts *httpClientOptions) {
opts.keepAlivesEnabled = false
}
}

// WithHTTP2Disabled allows to disable HTTP2.
func WithHTTP2Disabled() HTTPClientOption {
return func(opts *httpClientOptions) {
opts.http2Enabled = false
}
}

// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
}

// NewClientFromConfig returns a new HTTP client configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name, disableKeepAlives, enableHTTP2)
// given config.HTTPClientConfig and config.HTTPClientOption.
// The name is used as go-conntrack metric label.
func NewClientFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (*http.Client, error) {
rt, err := NewRoundTripperFromConfig(cfg, name, optFuncs...)
if err != nil {
return nil, err
}
Expand All @@ -216,29 +259,45 @@ func NewClientFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, e
}

// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig. The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, disableKeepAlives, enableHTTP2 bool) (http.RoundTripper, error) {
// given config.HTTPClientConfig and config.HTTPClientOption.
// The name is used as go-conntrack metric label.
func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
opts := defaultHTTPClientOptions
for _, f := range optFuncs {
f(&opts)
}

var dialContext func(ctx context.Context, network, addr string) (net.Conn, error)

if opts.dialContextFunc != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we build up a list and append to the list when needed, to avoid the repetition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that in first place but the data type is not exported. Not sure how I could do it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Build the list here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which data type should have the slice?

dialContext = conntrack.NewDialContextFunc(
conntrack.DialWithDialContextFunc((func(context.Context, string, string) (net.Conn, error))(opts.dialContextFunc)),
conntrack.DialWithTracing(),
conntrack.DialWithName(name))
} else {
dialContext = conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name))
}

newRT := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
// The only timeout we care about is the configured scrape timeout.
// It is applied on request. So we leave out any timings here.
var rt http.RoundTripper = &http.Transport{
Proxy: http.ProxyURL(cfg.ProxyURL.URL),
MaxIdleConns: 20000,
MaxIdleConnsPerHost: 1000, // see https://github.com/golang/go/issues/13801
DisableKeepAlives: disableKeepAlives,
DisableKeepAlives: !opts.keepAlivesEnabled,
TLSClientConfig: tlsConfig,
DisableCompression: true,
// 5 minutes is typically above the maximum sane scrape interval. So we can
// use keepalive for all configurations.
IdleConnTimeout: 5 * time.Minute,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: conntrack.NewDialContextFunc(
conntrack.DialWithTracing(),
conntrack.DialWithName(name),
),
DialContext: dialContext,
}
if enableHTTP2 {
if opts.http2Enabled {
// HTTP/2 support is golang has many problematic cornercases where
// dead connections would be kept and used in connection pools.
// https://github.com/golang/go/issues/32388
Expand Down
37 changes: 29 additions & 8 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package config

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -50,6 +53,7 @@ const (
MissingKey = "missing/secret.key"

ExpectedMessage = "I'm here to serve you!!!"
ExpectedError = "expected error"
AuthorizationCredentials = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
AuthorizationCredentialsFile = "testdata/bearer.token"
AuthorizationType = "APIKEY"
Expand Down Expand Up @@ -350,7 +354,7 @@ func TestNewClientFromConfig(t *testing.T) {
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test", false, true)
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
continue
Expand Down Expand Up @@ -400,7 +404,7 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
}

for _, invalidConfig := range newClientInvalidConfig {
client, err := NewClientFromConfig(invalidConfig.clientConfig, "test", false, true)
client, err := NewClientFromConfig(invalidConfig.clientConfig, "test")
if client != nil {
t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig)
}
Expand All @@ -413,6 +417,23 @@ func TestNewClientFromInvalidConfig(t *testing.T) {
}
}

func TestCustomDialContextFunc(t *testing.T) {
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
return nil, errors.New(ExpectedError)
}

cfg := HTTPClientConfig{}
client, err := NewClientFromConfig(cfg, "test", WithDialContextFunc(dialFn))
if err != nil {
t.Fatalf("Can't create a client from this config: %+v", cfg)
}

_, err = client.Get("http://localhost")
if err == nil || !strings.Contains(err.Error(), ExpectedError) {
t.Errorf("Expected error %q but got %q", ExpectedError, err)
}
}

func TestMissingBearerAuthFile(t *testing.T) {
cfg := HTTPClientConfig{
BearerTokenFile: MissingBearerTokenFile,
Expand All @@ -439,7 +460,7 @@ func TestMissingBearerAuthFile(t *testing.T) {
}
defer testServer.Close()

client, err := NewClientFromConfig(cfg, "test", false, true)
client, err := NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -637,7 +658,7 @@ func TestBasicAuthNoPassword(t *testing.T) {
if err != nil {
t.Fatalf("Error loading HTTP client config: %v", err)
}
client, err := NewClientFromConfig(*cfg, "test", false, true)
client, err := NewClientFromConfig(*cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand All @@ -663,7 +684,7 @@ func TestBasicAuthNoUsername(t *testing.T) {
if err != nil {
t.Fatalf("Error loading HTTP client config: %v", err)
}
client, err := NewClientFromConfig(*cfg, "test", false, true)
client, err := NewClientFromConfig(*cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand All @@ -689,7 +710,7 @@ func TestBasicAuthPasswordFile(t *testing.T) {
if err != nil {
t.Fatalf("Error loading HTTP client config: %v", err)
}
client, err := NewClientFromConfig(*cfg, "test", false, true)
client, err := NewClientFromConfig(*cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand Down Expand Up @@ -840,7 +861,7 @@ func TestTLSRoundTripper(t *testing.T) {
writeCertificate(bs, tc.cert, cert)
writeCertificate(bs, tc.key, key)
if c == nil {
c, err = NewClientFromConfig(cfg, "test", false, true)
c, err = NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand Down Expand Up @@ -912,7 +933,7 @@ func TestTLSRoundTripperRaces(t *testing.T) {
writeCertificate(bs, TLSCAChainPath, ca)
writeCertificate(bs, ClientCertificatePath, cert)
writeCertificate(bs, ClientKeyNoPassPath, key)
c, err = NewClientFromConfig(cfg, "test", false, true)
c, err = NewClientFromConfig(cfg, "test")
if err != nil {
t.Fatalf("Error creating HTTP Client: %v", err)
}
Expand Down