Skip to content

net/http: add MaxBytesHandler #48104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/net/http/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6609,3 +6609,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
}
}
}

func TestMaxBytesHandler(t *testing.T) {
setParallel(t)
defer afterTest(t)

for _, maxSize := range []int64{100, 1_000, 1_000_000} {
for _, requestSize := range []int64{100, 1_000, 1_000_000} {
t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
func(t *testing.T) {
testMaxBytesHandler(t, maxSize, requestSize)
})
}
}
}

func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
var (
handlerN int64
handlerErr error
)
echo := HandlerFunc(func(w ResponseWriter, r *Request) {
var buf bytes.Buffer
handlerN, handlerErr = io.Copy(&buf, r.Body)
io.Copy(w, &buf)
})

ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
defer ts.Close()

c := ts.Client()
var buf strings.Builder
body := strings.NewReader(strings.Repeat("a", int(requestSize)))
res, err := c.Post(ts.URL, "text/plain", body)
if err != nil {
t.Errorf("unexpected connection error: %v", err)
} else {
_, err = io.Copy(&buf, res.Body)
res.Body.Close()
if err != nil {
t.Errorf("unexpected read error: %v", err)
}
}
if handlerN > maxSize {
t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
}
if requestSize > maxSize && handlerErr == nil {
t.Error("expected error on handler side; got nil")
}
if requestSize <= maxSize {
if handlerErr != nil {
t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
}
if handlerN != requestSize {
t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
}
}
if buf.Len() != int(handlerN) {
t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
}
}
9 changes: 9 additions & 0 deletions src/net/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3567,3 +3567,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
}
return false
}

// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
func MaxBytesHandler(h Handler, n int64) Handler {
return HandlerFunc(func(w ResponseWriter, r *Request) {
r2 := *r

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, may I ask you what's the purpose of copying request here?
I already use my own implementation in middleware like this, for a while already, and didn't get into any issues so far:

func maxBodyMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r.Body = http.MaxBytesReader(w, r.Body, int64(cfg.MaxBodySize))

		next.ServeHTTP(w, r)
	})
}

so would like to know if there are some edge cases which I missed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the docs for http.Handler,

Except for reading the body, handlers should not modify the provided Request.

You will often be fine modifying a request, but you aren’t supposed to do it because it might be a race condition if some other middleware happens to read the request at the same time.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thank you!

r2.Body = MaxBytesReader(w, r.Body, n)
h.ServeHTTP(w, &r2)
})
}