@@ -5059,3 +5059,146 @@ func TestTransportRequestReplayable(t *testing.T) {
5059
5059
})
5060
5060
}
5061
5061
}
5062
+
5063
+ // testMockTCPConn is a mock TCP connection used to test that
5064
+ // ReadFrom is called when sending the request body.
5065
+ type testMockTCPConn struct {
5066
+ * net.TCPConn
5067
+
5068
+ ReadFromCalled bool
5069
+ }
5070
+
5071
+ func (c * testMockTCPConn ) ReadFrom (r io.Reader ) (int64 , error ) {
5072
+ c .ReadFromCalled = true
5073
+ return c .TCPConn .ReadFrom (r )
5074
+ }
5075
+
5076
+ func TestTransportRequestWriteRoundTrip (t * testing.T ) {
5077
+ nBytes := int64 (1 << 10 )
5078
+ newFileFunc := func () (r io.Reader , done func (), err error ) {
5079
+ f , err := ioutil .TempFile ("" , "net-http-newfilefunc" )
5080
+ if err != nil {
5081
+ return nil , nil , err
5082
+ }
5083
+
5084
+ // Write some bytes to the file to enable reading.
5085
+ if _ , err := io .CopyN (f , rand .Reader , nBytes ); err != nil {
5086
+ return nil , nil , fmt .Errorf ("failed to write data to file: %v" , err )
5087
+ }
5088
+ if _ , err := f .Seek (0 , 0 ); err != nil {
5089
+ return nil , nil , fmt .Errorf ("failed to seek to front: %v" , err )
5090
+ }
5091
+
5092
+ done = func () {
5093
+ f .Close ()
5094
+ os .Remove (f .Name ())
5095
+ }
5096
+
5097
+ return f , done , nil
5098
+ }
5099
+
5100
+ newBufferFunc := func () (io.Reader , func (), error ) {
5101
+ return bytes .NewBuffer (make ([]byte , nBytes )), func () {}, nil
5102
+ }
5103
+
5104
+ cases := []struct {
5105
+ name string
5106
+ readerFunc func () (io.Reader , func (), error )
5107
+ contentLength int64
5108
+ expectedReadFrom bool
5109
+ }{
5110
+ {
5111
+ name : "file, length" ,
5112
+ readerFunc : newFileFunc ,
5113
+ contentLength : nBytes ,
5114
+ expectedReadFrom : true ,
5115
+ },
5116
+ {
5117
+ name : "file, no length" ,
5118
+ readerFunc : newFileFunc ,
5119
+ },
5120
+ {
5121
+ name : "file, negative length" ,
5122
+ readerFunc : newFileFunc ,
5123
+ contentLength : - 1 ,
5124
+ },
5125
+ {
5126
+ name : "buffer" ,
5127
+ contentLength : nBytes ,
5128
+ readerFunc : newBufferFunc ,
5129
+ },
5130
+ {
5131
+ name : "buffer, no length" ,
5132
+ readerFunc : newBufferFunc ,
5133
+ },
5134
+ {
5135
+ name : "buffer, length -1" ,
5136
+ contentLength : - 1 ,
5137
+ readerFunc : newBufferFunc ,
5138
+ },
5139
+ }
5140
+
5141
+ for _ , tc := range cases {
5142
+ t .Run (tc .name , func (t * testing.T ) {
5143
+ r , cleanup , err := tc .readerFunc ()
5144
+ if err != nil {
5145
+ t .Fatal (err )
5146
+ }
5147
+ defer cleanup ()
5148
+
5149
+ tConn := & testMockTCPConn {}
5150
+ trFunc := func (tr * Transport ) {
5151
+ tr .DialContext = func (ctx context.Context , network , addr string ) (net.Conn , error ) {
5152
+ var d net.Dialer
5153
+ conn , err := d .DialContext (ctx , network , addr )
5154
+ if err != nil {
5155
+ return nil , err
5156
+ }
5157
+
5158
+ tcpConn , ok := conn .(* net.TCPConn )
5159
+ if ! ok {
5160
+ return nil , fmt .Errorf ("%s/%s does not provide a *net.TCPConn" , network , addr )
5161
+ }
5162
+
5163
+ tConn .TCPConn = tcpConn
5164
+ return tConn , nil
5165
+ }
5166
+ }
5167
+
5168
+ cst := newClientServerTest (
5169
+ t ,
5170
+ h1Mode ,
5171
+ HandlerFunc (func (w ResponseWriter , r * Request ) {
5172
+ io .Copy (ioutil .Discard , r .Body )
5173
+ r .Body .Close ()
5174
+ w .WriteHeader (200 )
5175
+ }),
5176
+ trFunc ,
5177
+ )
5178
+ defer cst .close ()
5179
+
5180
+ req , err := NewRequest ("PUT" , cst .ts .URL , r )
5181
+ if err != nil {
5182
+ t .Fatal (err )
5183
+ }
5184
+ req .ContentLength = tc .contentLength
5185
+ req .Header .Set ("Content-Type" , "application/octet-stream" )
5186
+ resp , err := cst .c .Do (req )
5187
+ if err != nil {
5188
+ t .Fatal (err )
5189
+ }
5190
+ defer resp .Body .Close ()
5191
+ if resp .StatusCode != 200 {
5192
+ t .Fatalf ("status code = %d; want 200" , resp .StatusCode )
5193
+ }
5194
+
5195
+ if ! tConn .ReadFromCalled && tc .expectedReadFrom {
5196
+ t .Fatalf ("did not call ReadFrom" )
5197
+ }
5198
+
5199
+ if tConn .ReadFromCalled && ! tc .expectedReadFrom {
5200
+ t .Fatalf ("ReadFrom was unexpectedly invoked" )
5201
+ }
5202
+ })
5203
+ }
5204
+ }
0 commit comments