@@ -6609,3 +6609,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
6609
6609
}
6610
6610
}
6611
6611
}
6612
+
6613
+ func TestMaxBytesHandler (t * testing.T ) {
6614
+ setParallel (t )
6615
+ defer afterTest (t )
6616
+
6617
+ for _ , maxSize := range []int64 {100 , 1_000 , 1_000_000 } {
6618
+ for _ , requestSize := range []int64 {100 , 1_000 , 1_000_000 } {
6619
+ t .Run (fmt .Sprintf ("max size %d request size %d" , maxSize , requestSize ),
6620
+ func (t * testing.T ) {
6621
+ testMaxBytesHandler (t , maxSize , requestSize )
6622
+ })
6623
+ }
6624
+ }
6625
+ }
6626
+
6627
+ func testMaxBytesHandler (t * testing.T , maxSize , requestSize int64 ) {
6628
+ var (
6629
+ handlerN int64
6630
+ handlerErr error
6631
+ )
6632
+ echo := HandlerFunc (func (w ResponseWriter , r * Request ) {
6633
+ var buf bytes.Buffer
6634
+ handlerN , handlerErr = io .Copy (& buf , r .Body )
6635
+ io .Copy (w , & buf )
6636
+ })
6637
+
6638
+ ts := httptest .NewServer (MaxBytesHandler (echo , maxSize ))
6639
+ defer ts .Close ()
6640
+
6641
+ c := ts .Client ()
6642
+ var buf strings.Builder
6643
+ body := strings .NewReader (strings .Repeat ("a" , int (requestSize )))
6644
+ res , err := c .Post (ts .URL , "text/plain" , body )
6645
+ if err != nil {
6646
+ t .Errorf ("unexpected connection error: %v" , err )
6647
+ } else {
6648
+ _ , err = io .Copy (& buf , res .Body )
6649
+ res .Body .Close ()
6650
+ if err != nil {
6651
+ t .Errorf ("unexpected read error: %v" , err )
6652
+ }
6653
+ }
6654
+ if handlerN > maxSize {
6655
+ t .Errorf ("expected max request body %d; got %d" , maxSize , handlerN )
6656
+ }
6657
+ if requestSize > maxSize && handlerErr == nil {
6658
+ t .Error ("expected error on handler side; got nil" )
6659
+ }
6660
+ if requestSize <= maxSize {
6661
+ if handlerErr != nil {
6662
+ t .Errorf ("%d expected nil error on handler side; got %v" , requestSize , handlerErr )
6663
+ }
6664
+ if handlerN != requestSize {
6665
+ t .Errorf ("expected request of size %d; got %d" , requestSize , handlerN )
6666
+ }
6667
+ }
6668
+ if buf .Len () != int (handlerN ) {
6669
+ t .Errorf ("expected echo of size %d; got %d" , handlerN , buf .Len ())
6670
+ }
6671
+ }
0 commit comments