diff --git a/go.mod b/go.mod index bd0c6af67c1..e5eb85da5fb 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/golang/snappy v0.0.4 github.com/gorilla/mux v1.8.0 github.com/grafana/dskit v0.0.0-20220105080720-01ce9286d7d5 + github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/json-iterator/go v1.1.12 github.com/lib/pq v1.3.0 github.com/minio/minio-go/v7 v7.0.10 diff --git a/pkg/alertmanager/alertmanager_client.go b/pkg/alertmanager/alertmanager_client.go index 57571a31464..e95d8708ae1 100644 --- a/pkg/alertmanager/alertmanager_client.go +++ b/pkg/alertmanager/alertmanager_client.go @@ -5,8 +5,6 @@ import ( "time" "github.com/go-kit/log" - "github.com/grafana/dskit/crypto/tls" - "github.com/grafana/dskit/grpcclient" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -15,6 +13,8 @@ import ( "github.com/cortexproject/cortex/pkg/alertmanager/alertmanagerpb" "github.com/cortexproject/cortex/pkg/ring/client" + "github.com/cortexproject/cortex/pkg/util/grpcclient" + "github.com/cortexproject/cortex/pkg/util/tls" ) // ClientsPool is the interface used to get the client from the pool for a specified address. diff --git a/pkg/chunk/gcp/bigtable_index_client.go b/pkg/chunk/gcp/bigtable_index_client.go index 7b516212aab..fc716edbb5b 100644 --- a/pkg/chunk/gcp/bigtable_index_client.go +++ b/pkg/chunk/gcp/bigtable_index_client.go @@ -12,12 +12,12 @@ import ( "cloud.google.com/go/bigtable" "github.com/go-kit/log" - "github.com/grafana/dskit/grpcclient" ot "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/cortexproject/cortex/pkg/chunk" chunk_util "github.com/cortexproject/cortex/pkg/chunk/util" + "github.com/cortexproject/cortex/pkg/util/grpcclient" "github.com/cortexproject/cortex/pkg/util/math" "github.com/cortexproject/cortex/pkg/util/spanlogger" ) diff --git a/pkg/configs/client/client.go b/pkg/configs/client/client.go index 0e7561aa6bc..4aaebc37d55 100644 --- a/pkg/configs/client/client.go +++ b/pkg/configs/client/client.go @@ -12,7 +12,6 @@ import ( "time" "github.com/go-kit/log/level" - dstls "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -21,6 +20,7 @@ import ( "github.com/cortexproject/cortex/pkg/configs/userconfig" util_log "github.com/cortexproject/cortex/pkg/util/log" + tls_cfg "github.com/cortexproject/cortex/pkg/util/tls" ) var ( @@ -29,9 +29,9 @@ var ( // Config says where we can find the ruler userconfig. type Config struct { - ConfigsAPIURL flagext.URLValue `yaml:"configs_api_url"` - ClientTimeout time.Duration `yaml:"client_timeout"` // HTTP timeout duration for requests made to the Weave Cloud configs service. - TLS dstls.ClientConfig `yaml:",inline"` + ConfigsAPIURL flagext.URLValue `yaml:"configs_api_url"` + ClientTimeout time.Duration `yaml:"client_timeout"` // HTTP timeout duration for requests made to the Weave Cloud configs service. + TLS tls_cfg.ClientConfig `yaml:",inline"` } // RegisterFlagsWithPrefix adds the flags required to config this to the given FlagSet diff --git a/pkg/cortex/cortex.go b/pkg/cortex/cortex.go index 0a6ca6c2c1f..9f91ab0f421 100644 --- a/pkg/cortex/cortex.go +++ b/pkg/cortex/cortex.go @@ -13,7 +13,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/flagext" - "github.com/grafana/dskit/grpcutil" "github.com/grafana/dskit/kv/memberlist" "github.com/grafana/dskit/modules" "github.com/grafana/dskit/runtimeconfig" @@ -59,6 +58,7 @@ import ( "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" "github.com/cortexproject/cortex/pkg/util/fakeauth" + "github.com/cortexproject/cortex/pkg/util/grpcutil" util_log "github.com/cortexproject/cortex/pkg/util/log" "github.com/cortexproject/cortex/pkg/util/process" "github.com/cortexproject/cortex/pkg/util/validation" diff --git a/pkg/distributor/query.go b/pkg/distributor/query.go index e2d3e42d88a..0c2145cef8d 100644 --- a/pkg/distributor/query.go +++ b/pkg/distributor/query.go @@ -6,7 +6,6 @@ import ( "sort" "time" - "github.com/grafana/dskit/grpcutil" "github.com/opentracing/opentracing-go" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/model/labels" @@ -19,6 +18,7 @@ import ( "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" "github.com/cortexproject/cortex/pkg/util/extract" + "github.com/cortexproject/cortex/pkg/util/grpcutil" "github.com/cortexproject/cortex/pkg/util/limiter" "github.com/cortexproject/cortex/pkg/util/validation" ) diff --git a/pkg/frontend/v2/frontend.go b/pkg/frontend/v2/frontend.go index 540c515019e..2a1512e0c0f 100644 --- a/pkg/frontend/v2/frontend.go +++ b/pkg/frontend/v2/frontend.go @@ -12,7 +12,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/flagext" - "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/services" "github.com/opentracing/opentracing-go" "github.com/pkg/errors" @@ -24,6 +23,7 @@ import ( "github.com/cortexproject/cortex/pkg/frontend/v2/frontendv2pb" "github.com/cortexproject/cortex/pkg/querier/stats" "github.com/cortexproject/cortex/pkg/tenant" + "github.com/cortexproject/cortex/pkg/util/grpcclient" "github.com/cortexproject/cortex/pkg/util/httpgrpcutil" ) diff --git a/pkg/ingester/client/client.go b/pkg/ingester/client/client.go index 05f010d7175..6b017a20e26 100644 --- a/pkg/ingester/client/client.go +++ b/pkg/ingester/client/client.go @@ -4,11 +4,12 @@ import ( "flag" "github.com/go-kit/log" - "github.com/grafana/dskit/grpcclient" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "google.golang.org/grpc" "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/cortexproject/cortex/pkg/util/grpcclient" ) var ingesterClientRequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ diff --git a/pkg/ingester/client/cortex_util_test.go b/pkg/ingester/client/cortex_util_test.go index 06f6e94a8c2..e9c7a014df8 100644 --- a/pkg/ingester/client/cortex_util_test.go +++ b/pkg/ingester/client/cortex_util_test.go @@ -7,13 +7,13 @@ import ( "testing" "time" - "github.com/grafana/dskit/grpcutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/test/bufconn" + "github.com/cortexproject/cortex/pkg/util/grpcutil" "github.com/cortexproject/cortex/pkg/util/test" ) diff --git a/pkg/querier/store_gateway_client.go b/pkg/querier/store_gateway_client.go index 528c3c2fd93..b84a87846c1 100644 --- a/pkg/querier/store_gateway_client.go +++ b/pkg/querier/store_gateway_client.go @@ -5,8 +5,6 @@ import ( "time" "github.com/go-kit/log" - "github.com/grafana/dskit/crypto/tls" - "github.com/grafana/dskit/grpcclient" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -15,6 +13,8 @@ import ( "github.com/cortexproject/cortex/pkg/ring/client" "github.com/cortexproject/cortex/pkg/storegateway/storegatewaypb" + "github.com/cortexproject/cortex/pkg/util/grpcclient" + "github.com/cortexproject/cortex/pkg/util/tls" ) func newStoreGatewayClientFactory(clientCfg grpcclient.Config, reg prometheus.Registerer) client.PoolFactory { diff --git a/pkg/querier/store_gateway_client_test.go b/pkg/querier/store_gateway_client_test.go index afe8e897e7a..3583bee7e23 100644 --- a/pkg/querier/store_gateway_client_test.go +++ b/pkg/querier/store_gateway_client_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/grafana/dskit/flagext" - "github.com/grafana/dskit/grpcclient" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" @@ -16,6 +15,7 @@ import ( "google.golang.org/grpc" "github.com/cortexproject/cortex/pkg/storegateway/storegatewaypb" + "github.com/cortexproject/cortex/pkg/util/grpcclient" ) func Test_newStoreGatewayClientFactory(t *testing.T) { diff --git a/pkg/querier/worker/scheduler_processor.go b/pkg/querier/worker/scheduler_processor.go index ee0b2028425..b835485a2f5 100644 --- a/pkg/querier/worker/scheduler_processor.go +++ b/pkg/querier/worker/scheduler_processor.go @@ -9,7 +9,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/grafana/dskit/backoff" - "github.com/grafana/dskit/grpcclient" dsmiddleware "github.com/grafana/dskit/middleware" "github.com/grafana/dskit/services" otgrpc "github.com/opentracing-contrib/go-grpc" @@ -26,6 +25,7 @@ import ( querier_stats "github.com/cortexproject/cortex/pkg/querier/stats" "github.com/cortexproject/cortex/pkg/ring/client" "github.com/cortexproject/cortex/pkg/scheduler/schedulerpb" + "github.com/cortexproject/cortex/pkg/util/grpcclient" "github.com/cortexproject/cortex/pkg/util/httpgrpcutil" util_log "github.com/cortexproject/cortex/pkg/util/log" ) diff --git a/pkg/querier/worker/worker.go b/pkg/querier/worker/worker.go index 21e7c985665..9180a7b4a2c 100644 --- a/pkg/querier/worker/worker.go +++ b/pkg/querier/worker/worker.go @@ -9,7 +9,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" - "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/services" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -17,6 +16,7 @@ import ( "google.golang.org/grpc" "github.com/cortexproject/cortex/pkg/util" + "github.com/cortexproject/cortex/pkg/util/grpcclient" ) type Config struct { diff --git a/pkg/ruler/client_pool.go b/pkg/ruler/client_pool.go index 717d154e0fc..5eae26644bc 100644 --- a/pkg/ruler/client_pool.go +++ b/pkg/ruler/client_pool.go @@ -4,7 +4,6 @@ import ( "time" "github.com/go-kit/log" - "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/services" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -13,6 +12,7 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" "github.com/cortexproject/cortex/pkg/ring/client" + "github.com/cortexproject/cortex/pkg/util/grpcclient" ) // ClientsPool is the interface used to get the client from the pool for a specified address. diff --git a/pkg/ruler/client_pool_test.go b/pkg/ruler/client_pool_test.go index 2f4998f1cf2..14d7cfb8e71 100644 --- a/pkg/ruler/client_pool_test.go +++ b/pkg/ruler/client_pool_test.go @@ -6,13 +6,14 @@ import ( "testing" "github.com/grafana/dskit/flagext" - "github.com/grafana/dskit/grpcclient" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/weaveworks/common/user" "google.golang.org/grpc" + + "github.com/cortexproject/cortex/pkg/util/grpcclient" ) func Test_newRulerClientFactory(t *testing.T) { diff --git a/pkg/ruler/notifier.go b/pkg/ruler/notifier.go index d8f5a55b7d9..3b90b25ad35 100644 --- a/pkg/ruler/notifier.go +++ b/pkg/ruler/notifier.go @@ -11,7 +11,6 @@ import ( gklog "github.com/go-kit/log" "github.com/go-kit/log/level" - "github.com/grafana/dskit/crypto/tls" config_util "github.com/prometheus/common/config" "github.com/prometheus/common/model" "github.com/prometheus/prometheus/config" @@ -20,6 +19,7 @@ import ( "github.com/prometheus/prometheus/notifier" "github.com/cortexproject/cortex/pkg/util" + "github.com/cortexproject/cortex/pkg/util/tls" ) type NotifierConfig struct { diff --git a/pkg/ruler/ruler.go b/pkg/ruler/ruler.go index f0398c707db..b1f7c156acb 100644 --- a/pkg/ruler/ruler.go +++ b/pkg/ruler/ruler.go @@ -16,7 +16,6 @@ import ( "github.com/go-kit/log/level" "github.com/grafana/dskit/concurrency" "github.com/grafana/dskit/flagext" - "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/kv" "github.com/grafana/dskit/services" "github.com/pkg/errors" @@ -35,6 +34,7 @@ import ( "github.com/cortexproject/cortex/pkg/ruler/rulestore" "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" + "github.com/cortexproject/cortex/pkg/util/grpcclient" util_log "github.com/cortexproject/cortex/pkg/util/log" "github.com/cortexproject/cortex/pkg/util/validation" ) diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 3d3dbdc3441..543a567562f 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -10,7 +10,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" - "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/services" otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" @@ -27,6 +26,7 @@ import ( "github.com/cortexproject/cortex/pkg/scheduler/schedulerpb" "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" + "github.com/cortexproject/cortex/pkg/util/grpcclient" "github.com/cortexproject/cortex/pkg/util/httpgrpcutil" "github.com/cortexproject/cortex/pkg/util/validation" ) diff --git a/pkg/util/dns_watcher.go b/pkg/util/dns_watcher.go index 4b37852e445..6ff76de33c8 100644 --- a/pkg/util/dns_watcher.go +++ b/pkg/util/dns_watcher.go @@ -5,10 +5,10 @@ import ( "fmt" "time" - "github.com/grafana/dskit/grpcutil" "github.com/grafana/dskit/services" "github.com/pkg/errors" + "github.com/cortexproject/cortex/pkg/util/grpcutil" util_log "github.com/cortexproject/cortex/pkg/util/log" ) diff --git a/vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go b/pkg/util/grpcclient/backoff_retry.go similarity index 99% rename from vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go rename to pkg/util/grpcclient/backoff_retry.go index 21abbb78656..70ab38c6b73 100644 --- a/vendor/github.com/grafana/dskit/grpcclient/backoff_retry.go +++ b/pkg/util/grpcclient/backoff_retry.go @@ -3,11 +3,10 @@ package grpcclient import ( "context" + "github.com/grafana/dskit/backoff" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - - "github.com/grafana/dskit/backoff" ) // NewBackoffRetry gRPC middleware. diff --git a/vendor/github.com/grafana/dskit/grpcclient/grpcclient.go b/pkg/util/grpcclient/grpcclient.go similarity index 97% rename from vendor/github.com/grafana/dskit/grpcclient/grpcclient.go rename to pkg/util/grpcclient/grpcclient.go index 094337f5d2c..3dd1c0ceb1b 100644 --- a/vendor/github.com/grafana/dskit/grpcclient/grpcclient.go +++ b/pkg/util/grpcclient/grpcclient.go @@ -5,15 +5,15 @@ import ( "time" "github.com/go-kit/log" + "github.com/grafana/dskit/backoff" middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/keepalive" - "github.com/grafana/dskit/backoff" - "github.com/grafana/dskit/crypto/tls" - "github.com/grafana/dskit/grpcencoding/snappy" + "github.com/cortexproject/cortex/pkg/util/grpcencoding/snappy" + "github.com/cortexproject/cortex/pkg/util/tls" ) // Config for a gRPC client. diff --git a/vendor/github.com/grafana/dskit/grpcclient/instrumentation.go b/pkg/util/grpcclient/instrumentation.go similarity index 99% rename from vendor/github.com/grafana/dskit/grpcclient/instrumentation.go rename to pkg/util/grpcclient/instrumentation.go index b22a5883405..cf3dd392d6a 100644 --- a/vendor/github.com/grafana/dskit/grpcclient/instrumentation.go +++ b/pkg/util/grpcclient/instrumentation.go @@ -1,13 +1,12 @@ package grpcclient import ( + dsmiddleware "github.com/grafana/dskit/middleware" otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/weaveworks/common/middleware" "google.golang.org/grpc" - - dsmiddleware "github.com/grafana/dskit/middleware" ) func Instrument(requestDuration *prometheus.HistogramVec) ([]grpc.UnaryClientInterceptor, []grpc.StreamClientInterceptor) { diff --git a/vendor/github.com/grafana/dskit/grpcclient/ratelimit.go b/pkg/util/grpcclient/ratelimit.go similarity index 100% rename from vendor/github.com/grafana/dskit/grpcclient/ratelimit.go rename to pkg/util/grpcclient/ratelimit.go diff --git a/pkg/util/grpcclient/ratelimit_test.go b/pkg/util/grpcclient/ratelimit_test.go new file mode 100644 index 00000000000..6a8d6345b9b --- /dev/null +++ b/pkg/util/grpcclient/ratelimit_test.go @@ -0,0 +1,36 @@ +package grpcclient_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/cortexproject/cortex/pkg/util/grpcclient" +) + +func TestRateLimiterFailureResultsInResourceExhaustedError(t *testing.T) { + config := grpcclient.Config{ + RateLimitBurst: 0, + RateLimit: 0, + } + conn := grpc.ClientConn{} + invoker := func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error { + return nil + } + + limiter := grpcclient.NewRateLimiter(&config) + err := limiter(context.Background(), "methodName", "", "expectedReply", &conn, invoker) + + if se, ok := err.(interface { + GRPCStatus() *status.Status + }); ok { + assert.Equal(t, se.GRPCStatus().Code(), codes.ResourceExhausted) + assert.Equal(t, se.GRPCStatus().Message(), "rate: Wait(n=1) exceeds limiter's burst 0") + } else { + assert.Fail(t, "Could not convert error into expected Status type") + } +} diff --git a/vendor/github.com/grafana/dskit/grpcencoding/snappy/snappy.go b/pkg/util/grpcencoding/snappy/snappy.go similarity index 100% rename from vendor/github.com/grafana/dskit/grpcencoding/snappy/snappy.go rename to pkg/util/grpcencoding/snappy/snappy.go diff --git a/pkg/util/grpcencoding/snappy/snappy_test.go b/pkg/util/grpcencoding/snappy/snappy_test.go new file mode 100644 index 00000000000..d288c95c215 --- /dev/null +++ b/pkg/util/grpcencoding/snappy/snappy_test.go @@ -0,0 +1,70 @@ +package snappy + +import ( + "bytes" + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSnappy(t *testing.T) { + c := newCompressor() + assert.Equal(t, "snappy", c.Name()) + + tests := []struct { + test string + input string + }{ + {"empty", ""}, + {"short", "hello world"}, + {"long", strings.Repeat("123456789", 1024)}, + } + for _, test := range tests { + t.Run(test.test, func(t *testing.T) { + var buf bytes.Buffer + // Compress + w, err := c.Compress(&buf) + require.NoError(t, err) + n, err := w.Write([]byte(test.input)) + require.NoError(t, err) + assert.Len(t, test.input, n) + err = w.Close() + require.NoError(t, err) + // Decompress + r, err := c.Decompress(&buf) + require.NoError(t, err) + out, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, test.input, string(out)) + }) + } +} + +func BenchmarkSnappyCompress(b *testing.B) { + data := []byte(strings.Repeat("123456789", 1024)) + c := newCompressor() + b.ResetTimer() + for i := 0; i < b.N; i++ { + w, _ := c.Compress(io.Discard) + _, _ = w.Write(data) + _ = w.Close() + } +} + +func BenchmarkSnappyDecompress(b *testing.B) { + data := []byte(strings.Repeat("123456789", 1024)) + c := newCompressor() + var buf bytes.Buffer + w, _ := c.Compress(&buf) + _, _ = w.Write(data) + reader := bytes.NewReader(buf.Bytes()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + r, _ := c.Decompress(reader) + _, _ = io.ReadAll(r) + _, _ = reader.Seek(0, io.SeekStart) + } +} diff --git a/vendor/github.com/grafana/dskit/grpcutil/dns_resolver.go b/pkg/util/grpcutil/dns_resolver.go similarity index 100% rename from vendor/github.com/grafana/dskit/grpcutil/dns_resolver.go rename to pkg/util/grpcutil/dns_resolver.go diff --git a/vendor/github.com/grafana/dskit/grpcutil/health_check.go b/pkg/util/grpcutil/health_check.go similarity index 99% rename from vendor/github.com/grafana/dskit/grpcutil/health_check.go rename to pkg/util/grpcutil/health_check.go index 2b567b36804..a566475b2ff 100644 --- a/vendor/github.com/grafana/dskit/grpcutil/health_check.go +++ b/pkg/util/grpcutil/health_check.go @@ -4,10 +4,9 @@ import ( "context" "github.com/gogo/status" + "github.com/grafana/dskit/services" "google.golang.org/grpc/codes" "google.golang.org/grpc/health/grpc_health_v1" - - "github.com/grafana/dskit/services" ) // HealthCheck fulfills the grpc_health_v1.HealthServer interface by ensuring diff --git a/pkg/util/grpcutil/health_check_test.go b/pkg/util/grpcutil/health_check_test.go new file mode 100644 index 00000000000..048da197ffb --- /dev/null +++ b/pkg/util/grpcutil/health_check_test.go @@ -0,0 +1,145 @@ +package grpcutil + +import ( + "context" + "testing" + + "github.com/grafana/dskit/services" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHealthCheck_isHealthy(t *testing.T) { + tests := map[string]struct { + states []services.State + expected bool + }{ + "all services are new": { + states: []services.State{services.New, services.New}, + expected: false, + }, + "all services are starting": { + states: []services.State{services.Starting, services.Starting}, + expected: false, + }, + "some services are starting and some running": { + states: []services.State{services.Starting, services.Running}, + expected: false, + }, + "all services are running": { + states: []services.State{services.Running, services.Running}, + expected: true, + }, + "some services are stopping": { + states: []services.State{services.Running, services.Stopping}, + expected: true, + }, + "some services are terminated while others running": { + states: []services.State{services.Running, services.Terminated}, + expected: true, + }, + "all services are stopping": { + states: []services.State{services.Stopping, services.Stopping}, + expected: true, + }, + "some services are terminated while others stopping": { + states: []services.State{services.Stopping, services.Terminated}, + expected: true, + }, + "a service has failed while others are running": { + states: []services.State{services.Running, services.Failed}, + expected: false, + }, + "all services are terminated": { + states: []services.State{services.Terminated, services.Terminated}, + expected: false, + }, + } + + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + var svcs []services.Service + for range testData.states { + svcs = append(svcs, &mockService{}) + } + + sm, err := services.NewManager(svcs...) + require.NoError(t, err) + + // Switch the state of each mocked services. + for i, s := range svcs { + s.(*mockService).switchState(testData.states[i]) + } + + h := NewHealthCheck(sm) + assert.Equal(t, testData.expected, h.isHealthy()) + }) + } +} + +type mockService struct { + services.Service + state services.State + listeners []services.Listener +} + +func (s *mockService) switchState(desiredState services.State) { + // Simulate all the states between the current state and the desired one. + orderedStates := []services.State{services.New, services.Starting, services.Running, services.Failed, services.Stopping, services.Terminated} + simulationStarted := false + + for _, orderedState := range orderedStates { + // Skip until we reach the current state. + if !simulationStarted && orderedState != s.state { + continue + } + + // Start the simulation once we reach the current state. + if orderedState == s.state { + simulationStarted = true + continue + } + + // Skip the failed state, unless it's the desired one. + if orderedState == services.Failed && desiredState != services.Failed { + continue + } + + s.state = orderedState + + // Synchronously call listeners to avoid flaky tests. + for _, listener := range s.listeners { + switch orderedState { + case services.Starting: + listener.Starting() + case services.Running: + listener.Running() + case services.Stopping: + listener.Stopping(services.Running) + case services.Failed: + listener.Failed(services.Running, errors.New("mocked error")) + case services.Terminated: + listener.Terminated(services.Stopping) + } + } + + if orderedState == desiredState { + break + } + } +} + +func (s *mockService) State() services.State { + return s.state +} + +func (s *mockService) AddListener(listener services.Listener) { + s.listeners = append(s.listeners, listener) +} + +func (s *mockService) StartAsync(_ context.Context) error { return nil } +func (s *mockService) AwaitRunning(_ context.Context) error { return nil } +func (s *mockService) StopAsync() {} +func (s *mockService) AwaitTerminated(_ context.Context) error { return nil } +func (s *mockService) FailureCase() error { return nil } diff --git a/vendor/github.com/grafana/dskit/grpcutil/naming.go b/pkg/util/grpcutil/naming.go similarity index 100% rename from vendor/github.com/grafana/dskit/grpcutil/naming.go rename to pkg/util/grpcutil/naming.go diff --git a/vendor/github.com/grafana/dskit/grpcutil/util.go b/pkg/util/grpcutil/util.go similarity index 100% rename from vendor/github.com/grafana/dskit/grpcutil/util.go rename to pkg/util/grpcutil/util.go diff --git a/pkg/util/tls/test/tls_integration_test.go b/pkg/util/tls/test/tls_integration_test.go new file mode 100644 index 00000000000..d599ec9fa0d --- /dev/null +++ b/pkg/util/tls/test/tls_integration_test.go @@ -0,0 +1,567 @@ +package test + +import ( + "context" + "crypto/x509" + "crypto/x509/pkix" + "flag" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/gogo/status" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/weaveworks/common/server" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/cortexproject/cortex/integration/ca" + "github.com/cortexproject/cortex/pkg/util/grpcclient" + "github.com/cortexproject/cortex/pkg/util/tls" +) + +type tcIntegrationClientServer struct { + name string + tlsGrpcEnabled bool + tlsConfig tls.ClientConfig + httpExpectError func(*testing.T, error) + grpcExpectError func(*testing.T, error) +} + +type grpcHealthCheck struct { + healthy bool +} + +func (h *grpcHealthCheck) Check(_ context.Context, _ *grpc_health_v1.HealthCheckRequest) (*grpc_health_v1.HealthCheckResponse, error) { + if !h.healthy { + return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil + } + + return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil +} + +func (h *grpcHealthCheck) Watch(_ *grpc_health_v1.HealthCheckRequest, _ grpc_health_v1.Health_WatchServer) error { + return status.Error(codes.Unimplemented, "Watching is not supported") +} + +func getLocalHostPort() (int, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, err + } + + if err := l.Close(); err != nil { + return 0, err + } + return l.Addr().(*net.TCPAddr).Port, nil +} + +func newIntegrationClientServer( + t *testing.T, + cfg server.Config, + tcs []tcIntegrationClientServer, +) { + // server registers some metrics to default registry + savedRegistry := prometheus.DefaultRegisterer + prometheus.DefaultRegisterer = prometheus.NewRegistry() + defer func() { + prometheus.DefaultRegisterer = savedRegistry + }() + + grpcPort, err := getLocalHostPort() + require.NoError(t, err) + httpPort, err := getLocalHostPort() + require.NoError(t, err) + + cfg.HTTPListenPort = httpPort + cfg.GRPCListenPort = grpcPort + + serv, err := server.New(cfg) + require.NoError(t, err) + + serv.HTTP.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "OK") + }) + + grpc_health_v1.RegisterHealthServer(serv.GRPC, &grpcHealthCheck{true}) + + go func() { + err := serv.Run() + require.NoError(t, err) + }() + + httpURL := fmt.Sprintf("https://localhost:%d/hello", httpPort) + grpcHost := fmt.Sprintf("localhost:%d", grpcPort) + + for _, tc := range tcs { + tlsClientConfig, err := tc.tlsConfig.GetTLSConfig() + require.NoError(t, err) + + // HTTP + t.Run("HTTP/"+tc.name, func(t *testing.T) { + transport := &http.Transport{TLSClientConfig: tlsClientConfig} + client := &http.Client{Transport: transport} + + resp, err := client.Get(httpURL) + if err == nil { + defer resp.Body.Close() + } + if tc.httpExpectError != nil { + tc.httpExpectError(t, err) + return + } + if err != nil { + assert.NoError(t, err, tc.name) + return + } + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err, tc.name) + + assert.Equal(t, []byte("OK"), body, tc.name) + + }) + + // GRPC + t.Run("GRPC/"+tc.name, func(t *testing.T) { + clientConfig := grpcclient.Config{} + clientConfig.RegisterFlags(flag.NewFlagSet("fake", flag.ContinueOnError)) + + clientConfig.TLSEnabled = tc.tlsGrpcEnabled + clientConfig.TLS = tc.tlsConfig + + dialOptions, err := clientConfig.DialOption(nil, nil) + assert.NoError(t, err, tc.name) + dialOptions = append([]grpc.DialOption{grpc.WithDefaultCallOptions(clientConfig.CallOptions()...)}, dialOptions...) + + conn, err := grpc.Dial(grpcHost, dialOptions...) + assert.NoError(t, err, tc.name) + require.NoError(t, err, tc.name) + require.NoError(t, err, tc.name) + + client := grpc_health_v1.NewHealthClient(conn) + + // TODO: Investigate why the client doesn't really receive the + // error about the bad certificate from the server side and just + // see connection closed instead + resp, err := client.Check(context.TODO(), &grpc_health_v1.HealthCheckRequest{}) + if tc.grpcExpectError != nil { + tc.grpcExpectError(t, err) + return + } + assert.NoError(t, err) + if err == nil { + assert.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, resp.Status) + } + }) + + } + + serv.Shutdown() +} + +func TestServerWithoutTlsEnabled(t *testing.T) { + cfg := server.Config{} + (&cfg).RegisterFlags(flag.NewFlagSet("fake", flag.ContinueOnError)) + + newIntegrationClientServer( + t, + cfg, + []tcIntegrationClientServer{ + { + name: "no-config", + tlsConfig: tls.ClientConfig{}, + httpExpectError: errorContainsString("http: server gave HTTP response to HTTPS client"), + grpcExpectError: nil, + }, + { + name: "tls-enable", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{}, + httpExpectError: errorContainsString("http: server gave HTTP response to HTTPS client"), + grpcExpectError: errorContainsString("transport: authentication handshake failed: tls: first record does not look like a TLS handshake"), + }, + }, + ) +} + +func TestServerWithLocalhostCertNoClientCertAuth(t *testing.T) { + certs := setupCertificates(t) + + cfg := server.Config{} + (&cfg).RegisterFlags(flag.NewFlagSet("fake", flag.ContinueOnError)) + + unavailableDescErr := errorContainsString("rpc error: code = Unavailable desc =") + + cfg.HTTPTLSConfig.TLSCertPath = certs.serverCertFile + cfg.HTTPTLSConfig.TLSKeyPath = certs.serverKeyFile + cfg.GRPCTLSConfig.TLSCertPath = certs.serverCertFile + cfg.GRPCTLSConfig.TLSKeyPath = certs.serverKeyFile + + // Test a TLS server with localhost cert without any client certificate enforcement + newIntegrationClientServer( + t, + cfg, + []tcIntegrationClientServer{ + { + name: "no-config", + tlsConfig: tls.ClientConfig{}, + httpExpectError: errorContainsString("x509: certificate signed by unknown authority"), + // For GRPC we expect this error as we try to connect without TLS to a TLS enabled server + grpcExpectError: unavailableDescErr, + }, + { + name: "grpc-tls-enabled", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{}, + httpExpectError: errorContainsString("x509: certificate signed by unknown authority"), + grpcExpectError: errorContainsString("x509: certificate signed by unknown authority"), + }, + { + name: "tls-skip-verify", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + InsecureSkipVerify: true, + }, + }, + { + name: "tls-skip-verify-no-grpc-tls-enabled", + tlsGrpcEnabled: false, + tlsConfig: tls.ClientConfig{ + InsecureSkipVerify: true, + }, + grpcExpectError: unavailableDescErr, + }, + { + name: "ca-path-set", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + }, + }, + { + name: "ca-path-no-grpc-tls-enabled", + tlsGrpcEnabled: false, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + }, + grpcExpectError: unavailableDescErr, + }, + }, + ) +} + +func TestServerWithoutLocalhostCertNoClientCertAuth(t *testing.T) { + certs := setupCertificates(t) + + cfg := server.Config{} + (&cfg).RegisterFlags(flag.NewFlagSet("fake", flag.ContinueOnError)) + + unavailableDescErr := errorContainsString("rpc error: code = Unavailable desc =") + + // Test a TLS server without localhost cert without any client certificate enforcement + cfg.HTTPTLSConfig.TLSCertPath = certs.serverNoLocalhostCertFile + cfg.HTTPTLSConfig.TLSKeyPath = certs.serverNoLocalhostKeyFile + cfg.GRPCTLSConfig.TLSCertPath = certs.serverNoLocalhostCertFile + cfg.GRPCTLSConfig.TLSKeyPath = certs.serverNoLocalhostKeyFile + newIntegrationClientServer( + t, + cfg, + []tcIntegrationClientServer{ + { + name: "no-config", + tlsConfig: tls.ClientConfig{}, + httpExpectError: errorContainsString("x509: certificate is valid for my-other-name, not localhost"), + // For GRPC we expect this error as we try to connect without TLS to a TLS enabled server + grpcExpectError: unavailableDescErr, + }, + { + name: "grpc-tls-enabled", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{}, + httpExpectError: errorContainsString("x509: certificate is valid for my-other-name, not localhost"), + grpcExpectError: errorContainsString("x509: certificate is valid for my-other-name, not localhost"), + }, + { + name: "ca-path", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + }, + httpExpectError: errorContainsString("x509: certificate is valid for my-other-name, not localhost"), + grpcExpectError: errorContainsString("x509: certificate is valid for my-other-name, not localhost"), + }, + { + name: "server-name", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + ServerName: "my-other-name", + }, + }, + { + name: "tls-skip-verify", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + InsecureSkipVerify: true, + }, + }, + }, + ) +} + +func TestTLSServerWithLocalhostCertWithClientCertificateEnforcementUsingClientCA1(t *testing.T) { + certs := setupCertificates(t) + + cfg := server.Config{} + (&cfg).RegisterFlags(flag.NewFlagSet("fake", flag.ContinueOnError)) + + unavailableDescErr := errorContainsString("rpc error: code = Unavailable desc =") + + // Test a TLS server with localhost cert with client certificate enforcement through client CA 1 + cfg.HTTPTLSConfig.TLSCertPath = certs.serverCertFile + cfg.HTTPTLSConfig.TLSKeyPath = certs.serverKeyFile + cfg.HTTPTLSConfig.ClientCAs = certs.clientCA1CertFile + cfg.HTTPTLSConfig.ClientAuth = "RequireAndVerifyClientCert" + cfg.GRPCTLSConfig.TLSCertPath = certs.serverCertFile + cfg.GRPCTLSConfig.TLSKeyPath = certs.serverKeyFile + cfg.GRPCTLSConfig.ClientCAs = certs.clientCA1CertFile + cfg.GRPCTLSConfig.ClientAuth = "RequireAndVerifyClientCert" + + // TODO: Investigate why we don't really receive the error about the + // bad certificate from the server side and just see connection + // closed/reset instead + badCertErr := errorContainsString("remote error: tls: bad certificate") + newIntegrationClientServer( + t, + cfg, + []tcIntegrationClientServer{ + { + name: "tls-skip-verify", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + InsecureSkipVerify: true, + }, + httpExpectError: badCertErr, + grpcExpectError: unavailableDescErr, + }, + { + name: "ca-path", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + }, + httpExpectError: badCertErr, + grpcExpectError: unavailableDescErr, + }, + { + name: "ca-path-and-client-cert-ca1", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + CertPath: certs.client1CertFile, + KeyPath: certs.client1KeyFile, + }, + }, + { + name: "tls-skip-verify-and-client-cert-ca1", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + InsecureSkipVerify: true, + CertPath: certs.client1CertFile, + KeyPath: certs.client1KeyFile, + }, + }, + { + name: "ca-cert-and-client-cert-ca2", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + CertPath: certs.client2CertFile, + KeyPath: certs.client2KeyFile, + }, + httpExpectError: badCertErr, + grpcExpectError: unavailableDescErr, + }, + }, + ) +} + +func TestTLSServerWithLocalhostCertWithClientCertificateEnforcementUsingClientCA2(t *testing.T) { + certs := setupCertificates(t) + + cfg := server.Config{} + (&cfg).RegisterFlags(flag.NewFlagSet("fake", flag.ContinueOnError)) + + // Test a TLS server with localhost cert with client certificate enforcement through client CA 1 + cfg.HTTPTLSConfig.TLSCertPath = certs.serverCertFile + cfg.HTTPTLSConfig.TLSKeyPath = certs.serverKeyFile + cfg.HTTPTLSConfig.ClientCAs = certs.clientCABothCertFile + cfg.HTTPTLSConfig.ClientAuth = "RequireAndVerifyClientCert" + cfg.GRPCTLSConfig.TLSCertPath = certs.serverCertFile + cfg.GRPCTLSConfig.TLSKeyPath = certs.serverKeyFile + cfg.GRPCTLSConfig.ClientCAs = certs.clientCABothCertFile + cfg.GRPCTLSConfig.ClientAuth = "RequireAndVerifyClientCert" + + newIntegrationClientServer( + t, + cfg, + []tcIntegrationClientServer{ + { + name: "ca-cert-and-client-cert-ca1", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + CertPath: certs.client1CertFile, + KeyPath: certs.client1KeyFile, + }, + }, + { + name: "ca-cert-and-client-cert-ca2", + tlsGrpcEnabled: true, + tlsConfig: tls.ClientConfig{ + CAPath: certs.caCertFile, + CertPath: certs.client2CertFile, + KeyPath: certs.client2KeyFile, + }, + }, + }, + ) +} + +func setupCertificates(t *testing.T) keyMaterial { + testCADir, err := ioutil.TempDir("", "cortex-ca") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, os.RemoveAll(testCADir)) + }) + + // create server side CA + + testCA := ca.New("Cortex Test") + caCertFile := filepath.Join(testCADir, "ca.crt") + require.NoError(t, testCA.WriteCACertificate(caCertFile)) + + serverCertFile := filepath.Join(testCADir, "server.crt") + serverKeyFile := filepath.Join(testCADir, "server.key") + require.NoError(t, testCA.WriteCertificate( + &x509.Certificate{ + Subject: pkix.Name{CommonName: "server"}, + DNSNames: []string{"localhost", "my-other-name"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }, + serverCertFile, + serverKeyFile, + )) + + serverNoLocalhostCertFile := filepath.Join(testCADir, "server-no-localhost.crt") + serverNoLocalhostKeyFile := filepath.Join(testCADir, "server-no-localhost.key") + require.NoError(t, testCA.WriteCertificate( + &x509.Certificate{ + Subject: pkix.Name{CommonName: "server-no-localhost"}, + DNSNames: []string{"my-other-name"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }, + serverNoLocalhostCertFile, + serverNoLocalhostKeyFile, + )) + + // create client CAs + testClientCA1 := ca.New("Cortex Test Client CA 1") + testClientCA2 := ca.New("Cortex Test Client CA 2") + + clientCA1CertFile := filepath.Join(testCADir, "ca-client-1.crt") + require.NoError(t, testClientCA1.WriteCACertificate(clientCA1CertFile)) + clientCA2CertFile := filepath.Join(testCADir, "ca-client-2.crt") + require.NoError(t, testClientCA2.WriteCACertificate(clientCA2CertFile)) + + // create a ca file with both certs + clientCABothCertFile := filepath.Join(testCADir, "ca-client-both.crt") + func() { + src1, err := os.Open(clientCA1CertFile) + require.NoError(t, err) + defer src1.Close() + src2, err := os.Open(clientCA2CertFile) + require.NoError(t, err) + defer src2.Close() + + dst, err := os.Create(clientCABothCertFile) + require.NoError(t, err) + defer dst.Close() + + _, err = io.Copy(dst, src1) + require.NoError(t, err) + _, err = io.Copy(dst, src2) + require.NoError(t, err) + + }() + + client1CertFile := filepath.Join(testCADir, "client-1.crt") + client1KeyFile := filepath.Join(testCADir, "client-1.key") + require.NoError(t, testClientCA1.WriteCertificate( + &x509.Certificate{ + Subject: pkix.Name{CommonName: "client-1"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, + client1CertFile, + client1KeyFile, + )) + + client2CertFile := filepath.Join(testCADir, "client-2.crt") + client2KeyFile := filepath.Join(testCADir, "client-2.key") + require.NoError(t, testClientCA2.WriteCertificate( + &x509.Certificate{ + Subject: pkix.Name{CommonName: "client-2"}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, + client2CertFile, + client2KeyFile, + )) + + return keyMaterial{ + caCertFile: caCertFile, + serverCertFile: serverCertFile, + serverKeyFile: serverKeyFile, + serverNoLocalhostCertFile: serverNoLocalhostCertFile, + serverNoLocalhostKeyFile: serverNoLocalhostKeyFile, + clientCA1CertFile: clientCA1CertFile, + clientCABothCertFile: clientCABothCertFile, + client1CertFile: client1CertFile, + client1KeyFile: client1KeyFile, + client2CertFile: client2CertFile, + client2KeyFile: client2KeyFile, + } +} + +type keyMaterial struct { + caCertFile string + serverCertFile string + serverKeyFile string + serverNoLocalhostCertFile string + serverNoLocalhostKeyFile string + clientCA1CertFile string + clientCABothCertFile string + client1CertFile string + client1KeyFile string + client2CertFile string + client2KeyFile string +} + +func errorContainsString(str string) func(*testing.T, error) { + return func(t *testing.T, err error) { + require.Error(t, err) + assert.Contains(t, err.Error(), str) + } +} diff --git a/pkg/util/tls/tls.go b/pkg/util/tls/tls.go new file mode 100644 index 00000000000..9886b208ddc --- /dev/null +++ b/pkg/util/tls/tls.go @@ -0,0 +1,87 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "flag" + "io/ioutil" + + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// ClientConfig is the config for client TLS. +type ClientConfig struct { + CertPath string `yaml:"tls_cert_path"` + KeyPath string `yaml:"tls_key_path"` + CAPath string `yaml:"tls_ca_path"` + ServerName string `yaml:"tls_server_name"` + InsecureSkipVerify bool `yaml:"tls_insecure_skip_verify"` +} + +var ( + errKeyMissing = errors.New("certificate given but no key configured") + errCertMissing = errors.New("key given but no certificate configured") +) + +// RegisterFlagsWithPrefix registers flags with prefix. +func (cfg *ClientConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { + f.StringVar(&cfg.CertPath, prefix+".tls-cert-path", "", "Path to the client certificate file, which will be used for authenticating with the server. Also requires the key path to be configured.") + f.StringVar(&cfg.KeyPath, prefix+".tls-key-path", "", "Path to the key file for the client certificate. Also requires the client certificate to be configured.") + f.StringVar(&cfg.CAPath, prefix+".tls-ca-path", "", "Path to the CA certificates file to validate server certificate against. If not set, the host's root CA certificates are used.") + f.StringVar(&cfg.ServerName, prefix+".tls-server-name", "", "Override the expected name on the server certificate.") + f.BoolVar(&cfg.InsecureSkipVerify, prefix+".tls-insecure-skip-verify", false, "Skip validating server certificate.") +} + +// GetTLSConfig initialises tls.Config from config options +func (cfg *ClientConfig) GetTLSConfig() (*tls.Config, error) { + config := &tls.Config{ + InsecureSkipVerify: cfg.InsecureSkipVerify, + ServerName: cfg.ServerName, + } + + // read ca certificates + if cfg.CAPath != "" { + var caCertPool *x509.CertPool + caCert, err := ioutil.ReadFile(cfg.CAPath) + if err != nil { + return nil, errors.Wrapf(err, "error loading ca cert: %s", cfg.CAPath) + } + caCertPool = x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + config.RootCAs = caCertPool + } + + // read client certificate + if cfg.CertPath != "" || cfg.KeyPath != "" { + if cfg.CertPath == "" { + return nil, errCertMissing + } + if cfg.KeyPath == "" { + return nil, errKeyMissing + } + clientCert, err := tls.LoadX509KeyPair(cfg.CertPath, cfg.KeyPath) + if err != nil { + return nil, errors.Wrapf(err, "failed to load TLS certificate %s,%s", cfg.CertPath, cfg.KeyPath) + } + config.Certificates = []tls.Certificate{clientCert} + } + + return config, nil +} + +// GetGRPCDialOptions creates GRPC DialOptions for TLS +func (cfg *ClientConfig) GetGRPCDialOptions(enabled bool) ([]grpc.DialOption, error) { + if !enabled { + return []grpc.DialOption{grpc.WithInsecure()}, nil + } + + tlsConfig, err := cfg.GetTLSConfig() + if err != nil { + return nil, errors.Wrap(err, "error creating grpc dial options") + } + + return []grpc.DialOption{grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))}, nil +} diff --git a/pkg/util/tls/tls_test.go b/pkg/util/tls/tls_test.go new file mode 100644 index 00000000000..478ae02fc6f --- /dev/null +++ b/pkg/util/tls/tls_test.go @@ -0,0 +1,190 @@ +package tls + +import ( + "fmt" + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// certPEM and keyPEM are copied from the golang crypto/tls library +// https://github.com/golang/go/blob/7eb5941b95a588a23f18fa4c22fe42ff0119c311/src/crypto/tls/example_test.go#L127 +const certPEM = `-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----` +const keyPEM = `-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----` + +// caPEM is CA certificate of Let's Encrypt +// https://letsencrypt.org/certs/isrgrootx1.pem.txt +const caPEM = `-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE-----` + +type x509Paths struct { + cert string + key string + ca string +} + +func newTestX509Files(t *testing.T, cert, key, ca []byte) x509Paths { + + // create empty file + certsPath, err := ioutil.TempDir("", "*-x509") + require.NoError(t, err) + + t.Cleanup(func() { + os.RemoveAll(certsPath) + }) + + paths := x509Paths{ + cert: fmt.Sprintf("%s/cert.pem", certsPath), + key: fmt.Sprintf("%s/key.pem", certsPath), + ca: fmt.Sprintf("%s/ca.pem", certsPath), + } + + if cert != nil { + err = ioutil.WriteFile(paths.cert, cert, 0600) + require.NoError(t, err) + } + + if key != nil { + err = ioutil.WriteFile(paths.key, key, 0600) + require.NoError(t, err) + } + + if ca != nil { + err = ioutil.WriteFile(paths.ca, ca, 0600) + require.NoError(t, err) + } + + return paths +} + +func TestGetTLSConfig_ClientCerts(t *testing.T) { + paths := newTestX509Files(t, []byte(certPEM), []byte(keyPEM), nil) + + // test working certificate passed + c := &ClientConfig{ + CertPath: paths.cert, + KeyPath: paths.key, + } + tlsConfig, err := c.GetTLSConfig() + assert.NoError(t, err) + assert.Equal(t, false, tlsConfig.InsecureSkipVerify, "make sure we default to not skip verification") + assert.Equal(t, 1, len(tlsConfig.Certificates), "ensure a certificate is returned") + + // expect error with key and cert swapped passed along + c = &ClientConfig{ + CertPath: paths.key, + KeyPath: paths.cert, + } + _, err = c.GetTLSConfig() + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to find certificate PEM data in certificate input, but did find a private key") + + // expect error with only key passed along + c = &ClientConfig{ + KeyPath: paths.key, + } + _, err = c.GetTLSConfig() + assert.EqualError(t, err, errCertMissing.Error()) + + // expect error with only cert passed along + c = &ClientConfig{ + CertPath: paths.cert, + } + _, err = c.GetTLSConfig() + assert.EqualError(t, err, errKeyMissing.Error()) +} + +func TestGetTLSConfig_CA(t *testing.T) { + paths := newTestX509Files(t, nil, nil, []byte(certPEM)) + + // test single ca passed + c := &ClientConfig{ + CAPath: paths.ca, + } + tlsConfig, err := c.GetTLSConfig() + assert.NoError(t, err) + assert.Equal(t, 1, len(tlsConfig.RootCAs.Subjects()), "ensure one CA is returned") + assert.Equal(t, false, tlsConfig.InsecureSkipVerify, "make sure we default to not skip verification") + + // test two cas passed + paths = newTestX509Files(t, nil, nil, []byte(certPEM+"\n"+caPEM)) + c = &ClientConfig{ + CAPath: paths.ca, + } + tlsConfig, err = c.GetTLSConfig() + assert.NoError(t, err) + assert.Equal(t, 2, len(tlsConfig.RootCAs.Subjects()), "ensure two CAs are returned") + assert.False(t, tlsConfig.InsecureSkipVerify, "make sure we default to not skip verification") + + // expect errors to be passed + c = &ClientConfig{ + CAPath: paths.ca + "not-existing", + } + _, err = c.GetTLSConfig() + assert.Error(t, err) + assert.Contains(t, err.Error(), "error loading ca cert") +} + +func TestGetTLSConfig_InsecureSkipVerify(t *testing.T) { + c := &ClientConfig{ + InsecureSkipVerify: true, + } + tlsConfig, err := c.GetTLSConfig() + assert.NoError(t, err) + assert.True(t, tlsConfig.InsecureSkipVerify) +} + +func TestGetTLSConfig_ServerName(t *testing.T) { + c := &ClientConfig{ + ServerName: "myserver.com", + } + tlsConfig, err := c.GetTLSConfig() + assert.NoError(t, err) + assert.Equal(t, "myserver.com", tlsConfig.ServerName) +} diff --git a/tools/blocksconvert/planprocessor/service.go b/tools/blocksconvert/planprocessor/service.go index db4474d4dda..5b0bec17d17 100644 --- a/tools/blocksconvert/planprocessor/service.go +++ b/tools/blocksconvert/planprocessor/service.go @@ -14,7 +14,6 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" - "github.com/grafana/dskit/grpcclient" "github.com/grafana/dskit/services" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -23,6 +22,7 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" + "github.com/cortexproject/cortex/pkg/util/grpcclient" "github.com/cortexproject/cortex/tools/blocksconvert" ) diff --git a/vendor/modules.txt b/vendor/modules.txt index d92bd133335..3ba2e1a435a 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -361,9 +361,6 @@ github.com/grafana/dskit/backoff github.com/grafana/dskit/concurrency github.com/grafana/dskit/crypto/tls github.com/grafana/dskit/flagext -github.com/grafana/dskit/grpcclient -github.com/grafana/dskit/grpcencoding/snappy -github.com/grafana/dskit/grpcutil github.com/grafana/dskit/internal/math github.com/grafana/dskit/kv github.com/grafana/dskit/kv/codec @@ -379,6 +376,7 @@ github.com/grafana/dskit/runutil github.com/grafana/dskit/services github.com/grafana/dskit/test # github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 +## explicit github.com/grpc-ecosystem/go-grpc-middleware # github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0-rc.2.0.20201207153454-9f6bf00c00a7 github.com/grpc-ecosystem/go-grpc-middleware/v2