In Distributor, pre-allocate buffer for reading protobufs #1719

Merged: 3 commits, merged on Oct 23, 2019
2 changes: 2 additions & 0 deletions pkg/distributor/distributor.go
@@ -125,6 +125,7 @@ type Config struct {

HATrackerConfig HATrackerConfig `yaml:"ha_tracker,omitempty"`

MaxRecvMsgSize int `yaml:"max_send_msg_size"`
Contributor:
The YAML key here is max_send_msg_size, while the CLI flag is max-recv-msg-size. Isn't that contradictory?

Contributor:
Good catch, could you open a PR/issue so that we don't lose track of this?

Contributor:
Sure: #1755

Contributor (author):
oops, sorry!

RemoteTimeout time.Duration `yaml:"remote_timeout,omitempty"`
ExtraQueryDelay time.Duration `yaml:"extra_queue_delay,omitempty"`
LimiterReloadPeriod time.Duration `yaml:"limiter_reload_period,omitempty"`
@@ -142,6 +143,7 @@ func (cfg *Config) RegisterFlags(f *flag.FlagSet) {
cfg.HATrackerConfig.RegisterFlags(f)

f.BoolVar(&cfg.EnableBilling, "distributor.enable-billing", false, "Report number of ingested samples to billing system.")
f.IntVar(&cfg.MaxRecvMsgSize, "distributor.max-recv-msg-size", 100<<20, "remote_write API max receive message size (bytes).")
f.DurationVar(&cfg.RemoteTimeout, "distributor.remote-timeout", 2*time.Second, "Timeout for downstream ingesters.")
f.DurationVar(&cfg.ExtraQueryDelay, "distributor.extra-query-delay", 0, "Time to wait before sending more than the minimum successful query requests.")
f.DurationVar(&cfg.LimiterReloadPeriod, "distributor.limiter-reload-period", 5*time.Minute, "Period at which to reload user ingestion limits.")
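The review thread above flags a mismatch between the struct tag (max_send_msg_size) and the registered flag (distributor.max-recv-msg-size), tracked in #1755. A minimal sketch of why the tag matters, assuming the standard flag package and gopkg.in/yaml.v2 (illustrative only, not part of this PR): the YAML key is taken from the struct tag while the CLI name comes from flag registration, so as written the same field is set under two different names.

```go
package main

import (
	"flag"
	"fmt"

	"gopkg.in/yaml.v2"
)

// Config mirrors the field added above: the YAML key comes from the struct tag,
// the CLI name from flag registration, and the two are independent of each other.
type Config struct {
	MaxRecvMsgSize int `yaml:"max_send_msg_size"`
}

func main() {
	var cfg Config

	// YAML path: only the tag name is recognised, so the user has to write
	// max_send_msg_size even though the field limits received messages.
	if err := yaml.Unmarshal([]byte("max_send_msg_size: 1048576\n"), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.MaxRecvMsgSize) // 1048576

	// Flag path: uses the dash-separated name registered on the flag set.
	fs := flag.NewFlagSet("distributor", flag.ContinueOnError)
	fs.IntVar(&cfg.MaxRecvMsgSize, "distributor.max-recv-msg-size", 100<<20, "remote_write API max receive message size (bytes).")
	if err := fs.Parse([]string{"-distributor.max-recv-msg-size=2097152"}); err != nil {
		panic(err)
	}
	fmt.Println(cfg.MaxRecvMsgSize) // 2097152
}
```

With a plain yaml.v2 Unmarshal, an unrecognised key such as max_recv_msg_size is silently ignored, which is what makes this kind of mismatch easy to miss.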
2 changes: 1 addition & 1 deletion pkg/distributor/http_server.go
@@ -17,7 +17,7 @@ func (d *Distributor) PushHandler(w http.ResponseWriter, r *http.Request) {
compressionType := util.CompressionTypeFor(r.Header.Get("X-Prometheus-Remote-Write-Version"))
var req client.PreallocWriteRequest
req.Source = client.API
buf, err := util.ParseProtoReader(r.Context(), r.Body, &req, compressionType)
buf, err := util.ParseProtoReader(r.Context(), r.Body, int(r.ContentLength), d.cfg.MaxRecvMsgSize, &req, compressionType)
logger := util.WithContext(r.Context(), util.Logger)
if err != nil {
level.Error(logger).Log("err", err.Error())
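A side note on the size hint passed above (an observation, not something the diff spells out): net/http sets r.ContentLength to -1 when the request carries no Content-Length header, so expectedSize can be non-positive and the pre-allocation in ParseProtoReader (see pkg/util/http.go below) is simply skipped. A small standard-library-only sketch of that behaviour:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
)

func main() {
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This is the value the handler above would pass as expectedSize.
		fmt.Println("size hint:", int(r.ContentLength))
	})

	// A *strings.Reader body lets httptest fill in Content-Length: hint is 11.
	h.ServeHTTP(httptest.NewRecorder(),
		httptest.NewRequest(http.MethodPost, "/push", strings.NewReader("hello world")))

	// An opaque reader carries no length up front: hint is -1, so no pre-allocation.
	h.ServeHTTP(httptest.NewRecorder(),
		httptest.NewRequest(http.MethodPost, "/push", io.MultiReader(strings.NewReader("hello world"))))
}
```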
12 changes: 10 additions & 2 deletions pkg/ingester/client/client_test.go
@@ -12,10 +12,11 @@ import (

// TestMarshall is useful to try out various optimisation on the unmarshalling code.
func TestMarshall(t *testing.T) {
const numSeries = 10
recorder := httptest.NewRecorder()
{
req := WriteRequest{}
for i := 0; i < 10; i++ {
for i := 0; i < numSeries; i++ {
req.Timeseries = append(req.Timeseries, PreallocTimeseries{
&TimeSeries{
Labels: []LabelAdapter{
@@ -32,8 +33,15 @@ func TestMarshall(t *testing.T) {
}

{
const (
tooSmallSize = 1
plentySize = 1024 * 1024
)
req := WriteRequest{}
_, err := util.ParseProtoReader(context.Background(), recorder.Body, &req, util.RawSnappy)
_, err := util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), tooSmallSize, &req, util.RawSnappy)
require.Error(t, err)
_, err = util.ParseProtoReader(context.Background(), recorder.Body, recorder.Body.Len(), plentySize, &req, util.RawSnappy)
require.NoError(t, err)
require.Equal(t, numSeries, len(req.Timeseries))
}
}
5 changes: 4 additions & 1 deletion pkg/querier/remote_read.go
@@ -9,6 +9,9 @@ import (
"github.com/prometheus/prometheus/storage"
)

// Queries are a set of matchers with time ranges - should not get into megabytes
const maxRemoteReadQuerySize = 1024 * 1024

// RemoteReadHandler handles Prometheus remote read requests.
func RemoteReadHandler(q storage.Queryable) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -17,7 +20,7 @@ func RemoteReadHandler(q storage.Queryable) http.Handler {
ctx := r.Context()
var req client.ReadRequest
logger := util.WithContext(r.Context(), util.Logger)
if _, err := util.ParseProtoReader(ctx, r.Body, &req, compressionType); err != nil {
if _, err := util.ParseProtoReader(ctx, r.Body, int(r.ContentLength), maxRemoteReadQuerySize, &req, compressionType); err != nil {
level.Error(logger).Log("err", err.Error())
http.Error(w, err.Error(), http.StatusBadRequest)
return
26 changes: 20 additions & 6 deletions pkg/util/http.go
@@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"

"github.com/blang/semver"
@@ -57,31 +56,46 @@ func CompressionTypeFor(version string) CompressionType {
}

// ParseProtoReader parses a compressed proto from an io.Reader.
func ParseProtoReader(ctx context.Context, reader io.Reader, req proto.Message, compression CompressionType) ([]byte, error) {
func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, req proto.Message, compression CompressionType) ([]byte, error) {
var body []byte
var err error
sp := opentracing.SpanFromContext(ctx)
if sp != nil {
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[start reading]"))
}
var buf bytes.Buffer
if expectedSize > 0 {
if expectedSize > maxSize {
return nil, fmt.Errorf("message expected size larger than max (%d vs %d)", expectedSize, maxSize)
}
buf.Grow(expectedSize + bytes.MinRead) // extra space guarantees no reallocation
}
switch compression {
case NoCompression:
body, err = ioutil.ReadAll(reader)
// Read from LimitReader with limit max+1. So if the underlying
// reader is over limit, the result will be bigger than max.
_, err = buf.ReadFrom(io.LimitReader(reader, int64(maxSize)+1))
body = buf.Bytes()
case FramedSnappy:
body, err = ioutil.ReadAll(snappy.NewReader(reader))
_, err = buf.ReadFrom(io.LimitReader(snappy.NewReader(reader), int64(maxSize)+1))
body = buf.Bytes()
case RawSnappy:
body, err = ioutil.ReadAll(reader)
_, err = buf.ReadFrom(reader)
body = buf.Bytes()
if sp != nil {
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[decompress]"),
otlog.Int("size", len(body)))
}
if err == nil {
if err == nil && len(body) <= maxSize {
body, err = snappy.Decode(nil, body)
}
}
if err != nil {
return nil, err
}
if len(body) > maxSize {
return nil, fmt.Errorf("received message larger than max (%d vs %d)", len(body), maxSize)
}

if sp != nil {
sp.LogFields(otlog.String("event", "util.ParseProtoRequest[unmarshal]"),
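For reference, a stripped-down, self-contained sketch of the pattern this change introduces (illustrative names, not the actual Cortex code): grow the buffer up front when a size hint is available, read through io.LimitReader with a limit of maxSize+1, and reject anything that still comes out longer than maxSize.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// readLimited mirrors the uncompressed path of ParseProtoReader: pre-allocate
// when the expected size is known, and cap the read at maxSize+1 bytes so an
// oversized payload is detected without buffering it in full.
func readLimited(r io.Reader, expectedSize, maxSize int) ([]byte, error) {
	var buf bytes.Buffer
	if expectedSize > 0 {
		if expectedSize > maxSize {
			return nil, fmt.Errorf("message expected size larger than max (%d vs %d)", expectedSize, maxSize)
		}
		buf.Grow(expectedSize + bytes.MinRead) // extra space guarantees no reallocation
	}
	if _, err := buf.ReadFrom(io.LimitReader(r, int64(maxSize)+1)); err != nil {
		return nil, err
	}
	if buf.Len() > maxSize {
		return nil, fmt.Errorf("received message larger than max (%d vs %d)", buf.Len(), maxSize)
	}
	return buf.Bytes(), nil
}

func main() {
	body, err := readLimited(strings.NewReader("hello"), 5, 1024)
	fmt.Println(string(body), err) // hello <nil>

	_, err = readLimited(strings.NewReader(strings.Repeat("x", 2000)), 0, 1024)
	fmt.Println(err) // received message larger than max (1025 vs 1024)
}
```

The +1 on the LimitReader is the key detail: limiting to exactly maxSize would make an oversized payload indistinguishable from one that is exactly at the limit.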