@@ -109,7 +109,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
109
109
110
110
if ! c .flate () {
111
111
defer c .msgWriter .mu .unlock ()
112
- return c .writeFrame (ctx , true , false , c .msgWriter .opcode , p )
112
+ return c .writeFrame (true , ctx , true , false , c .msgWriter .opcode , p )
113
113
}
114
114
115
115
n , err := mw .Write (p )
@@ -159,6 +159,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
159
159
defer func () {
160
160
if err != nil {
161
161
err = fmt .Errorf ("failed to write: %w" , err )
162
+ mw .writeMu .unlock ()
162
163
mw .c .close (err )
163
164
}
164
165
}()
@@ -179,7 +180,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
179
180
}
180
181
181
182
func (mw * msgWriter ) write (p []byte ) (int , error ) {
182
- n , err := mw .c .writeFrame (mw .ctx , false , mw .flate , mw .opcode , p )
183
+ n , err := mw .c .writeFrame (true , mw .ctx , false , mw .flate , mw .opcode , p )
183
184
if err != nil {
184
185
return n , fmt .Errorf ("failed to write data frame: %w" , err )
185
186
}
@@ -191,25 +192,25 @@ func (mw *msgWriter) write(p []byte) (int, error) {
191
192
func (mw * msgWriter ) Close () (err error ) {
192
193
defer errd .Wrap (& err , "failed to close writer" )
193
194
194
- if mw .closed {
195
- return errors .New ("writer already closed" )
196
- }
197
- mw .closed = true
198
-
199
195
err = mw .writeMu .lock (mw .ctx )
200
196
if err != nil {
201
197
return err
202
198
}
203
199
defer mw .writeMu .unlock ()
204
200
201
+ if mw .closed {
202
+ return errors .New ("writer already closed" )
203
+ }
204
+ mw .closed = true
205
+
205
206
if mw .flate {
206
207
err = mw .flateWriter .Flush ()
207
208
if err != nil {
208
209
return fmt .Errorf ("failed to flush flate: %w" , err )
209
210
}
210
211
}
211
212
212
- _ , err = mw .c .writeFrame (mw .ctx , true , mw .flate , mw .opcode , nil )
213
+ _ , err = mw .c .writeFrame (true , mw .ctx , true , mw .flate , mw .opcode , nil )
213
214
if err != nil {
214
215
return fmt .Errorf ("failed to write fin frame: %w" , err )
215
216
}
@@ -235,15 +236,15 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
235
236
ctx , cancel := context .WithTimeout (ctx , time .Second * 5 )
236
237
defer cancel ()
237
238
238
- _ , err := c .writeFrame (ctx , true , false , opcode , p )
239
+ _ , err := c .writeFrame (false , ctx , true , false , opcode , p )
239
240
if err != nil {
240
241
return fmt .Errorf ("failed to write control frame %v: %w" , opcode , err )
241
242
}
242
243
return nil
243
244
}
244
245
245
246
// frame handles all writes to the connection.
246
- func (c * Conn ) writeFrame (ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
247
+ func (c * Conn ) writeFrame (msgWriter bool , ctx context.Context , fin bool , flate bool , opcode opcode , p []byte ) (_ int , err error ) {
247
248
err = c .writeFrameMu .lock (ctx )
248
249
if err != nil {
249
250
return 0 , err
@@ -283,6 +284,10 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
283
284
err = ctx .Err ()
284
285
default :
285
286
}
287
+ c .writeFrameMu .unlock ()
288
+ if msgWriter {
289
+ c .msgWriter .writeMu .unlock ()
290
+ }
286
291
c .close (err )
287
292
err = fmt .Errorf ("failed to write frame: %w" , err )
288
293
}
0 commit comments