Skip to content

Commit 4c8b99e

Browse files
committed
Saving WIP rewrite
1 parent 4a5bf3d commit 4c8b99e

File tree

10 files changed

+432
-579
lines changed

10 files changed

+432
-579
lines changed

close.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ func (ce CloseError) bytes() []byte {
130130
}
131131

132132
func (ce CloseError) bytesErr() ([]byte, error) {
133-
// TODO move check into frame write
134-
// if len(ce.Reason) > wsframe.MaxControlFramePayload-2 {
135-
// return nil, fmt.Errorf("reason string max is %v but got %q with length %v", wsframe.MaxControlFramePayload-2, ce.Reason, len(ce.Reason))
136-
// }
133+
const maxReason = maxControlPayload-2
134+
if len(ce.Reason) > maxReason {
135+
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxReason, ce.Reason, len(ce.Reason))
136+
}
137137
if !validWireCloseCode(ce.Code) {
138138
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
139139
}

conn.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ func (c *Conn) close(err error) {
116116
// closeErr.
117117
c.closer.Close()
118118

119+
// TODO acquire looks here in a new goroutine to ensure we don't get blocked.
119120
c.r.close()
120121
c.w.close()
121122
}
@@ -450,9 +451,6 @@ func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) err
450451
Reason: reason,
451452
}
452453

453-
// This function also will not wait for a close frame from the peer like the RFC
454-
// wants because that makes no sense and I don't think anyone actually follows that.
455-
// Definitely worth seeing what popular browsers do later.
456454
p := ce.bytes()
457455
return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake)
458456
}

internal/wsframe/mask.go renamed to frame.go

Lines changed: 181 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,189 @@
1-
package wsframe
1+
package websocket
22

33
import (
4+
"bufio"
45
"encoding/binary"
6+
"fmt"
7+
"math"
58
"math/bits"
69
)
710

8-
// Mask applies the WebSocket masking algorithm to p
11+
// opcode represents a WebSocket opcode.
12+
type opcode int
13+
14+
// List at https://tools.ietf.org/html/rfc6455#section-11.8.
15+
const (
16+
opContinuation opcode = iota
17+
opText
18+
opBinary
19+
// 3 - 7 are reserved for further non-control frames.
20+
_
21+
_
22+
_
23+
_
24+
_
25+
opClose
26+
opPing
27+
opPong
28+
// 11-16 are reserved for further control frames.
29+
)
30+
31+
// header represents a WebSocket frame header.
32+
// See https://tools.ietf.org/html/rfc6455#section-5.2.
33+
type header struct {
34+
fin bool
35+
rsv1 bool
36+
rsv2 bool
37+
rsv3 bool
38+
opcode opcode
39+
40+
payloadLength int64
41+
42+
masked bool
43+
maskKey uint32
44+
}
45+
46+
// ReadFrameHeader reads a header from the reader.
47+
// See https://tools.ietf.org/html/rfc6455#section-5.2.
48+
func ReadFrameHeader(r *bufio.Reader) (header, error) {
49+
h, err := readFrameHeader(r)
50+
if err != nil {
51+
return header{}, fmt.Errorf("failed to read frame header: %w", err)
52+
}
53+
return h, nil
54+
}
55+
56+
func readFrameHeader(r *bufio.Reader) (header, error) {
57+
b, err := r.ReadByte()
58+
if err != nil {
59+
return header{}, err
60+
}
61+
62+
var h header
63+
h.fin = b&(1<<7) != 0
64+
h.rsv1 = b&(1<<6) != 0
65+
h.rsv2 = b&(1<<5) != 0
66+
h.rsv3 = b&(1<<4) != 0
67+
68+
h.opcode = opcode(b & 0xf)
69+
70+
b, err = r.ReadByte()
71+
if err != nil {
72+
return header{}, err
73+
}
74+
75+
h.masked = b&(1<<7) != 0
76+
77+
payloadLength := b &^ (1 << 7)
78+
switch {
79+
case payloadLength < 126:
80+
h.payloadLength = int64(payloadLength)
81+
case payloadLength == 126:
82+
var pl uint16
83+
err = binary.Read(r, binary.BigEndian, &pl)
84+
h.payloadLength = int64(pl)
85+
case payloadLength == 127:
86+
err = binary.Read(r, binary.BigEndian, &h.payloadLength)
87+
}
88+
if err != nil {
89+
return header{}, err
90+
}
91+
92+
if h.masked {
93+
err = binary.Read(r, binary.LittleEndian, &h.maskKey)
94+
if err != nil {
95+
return header{}, err
96+
}
97+
}
98+
99+
return h, nil
100+
}
101+
102+
// maxControlPayload is the maximum length of a control frame payload.
103+
// See https://tools.ietf.org/html/rfc6455#section-5.5.
104+
const maxControlPayload = 125
105+
106+
// ParseClosePayload parses the bytes in p as a close frame payload,
107+
// returning the status code and reason.
108+
func ParseClosePayload(p []byte) (uint16, string, error) {
109+
if len(p) < 2 {
110+
return 0, "", fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
111+
}
112+
113+
return binary.BigEndian.Uint16(p), string(p[2:]), nil
114+
}
115+
116+
// WriteHeader writes the bytes of the header to w.
117+
// See https://tools.ietf.org/html/rfc6455#section-5.2
118+
func WriteHeader(h header, w *bufio.Writer) error {
119+
err := writeHeader(h, w)
120+
if err != nil {
121+
return fmt.Errorf("failed to write frame header: %w", err)
122+
}
123+
return nil
124+
}
125+
126+
func writeHeader(h header, w *bufio.Writer) error {
127+
var b byte
128+
if h.fin {
129+
b |= 1 << 7
130+
}
131+
if h.rsv1 {
132+
b |= 1 << 6
133+
}
134+
if h.rsv2 {
135+
b |= 1 << 5
136+
}
137+
if h.rsv3 {
138+
b |= 1 << 4
139+
}
140+
141+
b |= byte(h.opcode)
142+
143+
err := w.WriteByte(b)
144+
if err != nil {
145+
return err
146+
}
147+
148+
lengthByte := byte(0)
149+
if h.masked {
150+
lengthByte |= 1 << 7
151+
}
152+
153+
switch {
154+
case h.payloadLength > math.MaxUint16:
155+
lengthByte |= 127
156+
case h.payloadLength > 125:
157+
lengthByte |= 126
158+
case h.payloadLength >= 0:
159+
lengthByte |= byte(h.payloadLength)
160+
}
161+
err = w.WriteByte(lengthByte)
162+
if err != nil {
163+
return err
164+
}
165+
166+
switch {
167+
case h.payloadLength > math.MaxUint16:
168+
err = binary.Write(w, binary.BigEndian, h.payloadLength)
169+
case h.payloadLength > 125:
170+
err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength))
171+
}
172+
if err != nil {
173+
return err
174+
}
175+
176+
if h.masked {
177+
err = binary.Write(w, binary.LittleEndian, h.maskKey)
178+
if err != nil {
179+
return err
180+
}
181+
}
182+
183+
return nil
184+
}
185+
186+
// mask applies the WebSocket masking algorithm to p
9187
// with the given key.
10188
// See https://tools.ietf.org/html/rfc6455#section-5.3
11189
//
@@ -16,7 +194,7 @@ import (
16194
// to be in little endian.
17195
//
18196
// See https://github.com/golang/go/issues/31586
19-
func Mask(key uint32, b []byte) uint32 {
197+
func mask(key uint32, b []byte) uint32 {
20198
if len(b) >= 8 {
21199
key64 := uint64(key)<<32 | uint64(key)
22200

internal/wsframe/mask_test.go renamed to frame_test.go

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,115 @@
1-
package wsframe_test
1+
// +build !js
2+
3+
package websocket
24

35
import (
4-
"crypto/rand"
6+
"bufio"
7+
"bytes"
58
"encoding/binary"
6-
"github.com/gobwas/ws"
7-
"github.com/google/go-cmp/cmp"
89
"math/bits"
9-
"nhooyr.io/websocket/internal/wsframe"
10+
"nhooyr.io/websocket/internal/assert"
1011
"strconv"
1112
"testing"
13+
"time"
1214
_ "unsafe"
15+
16+
"github.com/gobwas/ws"
17+
"github.com/google/go-cmp/cmp"
18+
19+
_ "github.com/gorilla/websocket"
20+
"math/rand"
1321
)
1422

23+
func init() {
24+
rand.Seed(time.Now().UnixNano())
25+
}
26+
27+
func TestHeader(t *testing.T) {
28+
t.Parallel()
29+
30+
t.Run("readNegativeLength", func(t *testing.T) {
31+
t.Parallel()
32+
33+
r := bufio.NewReader(bytes.NewReader([]byte{}))
34+
_, err := ReadFrameHeader(r)
35+
assert.Error(t, err)
36+
})
37+
38+
t.Run("lengths", func(t *testing.T) {
39+
t.Parallel()
40+
41+
lengths := []int{
42+
124,
43+
125,
44+
126,
45+
127,
46+
47+
65534,
48+
65535,
49+
65536,
50+
65537,
51+
}
52+
53+
for _, n := range lengths {
54+
n := n
55+
t.Run(strconv.Itoa(n), func(t *testing.T) {
56+
t.Parallel()
57+
58+
testHeader(t, header{
59+
payloadLength: int64(n),
60+
})
61+
})
62+
}
63+
})
64+
65+
t.Run("fuzz", func(t *testing.T) {
66+
t.Parallel()
67+
68+
randBool := func() bool {
69+
return rand.Intn(1) == 0
70+
}
71+
72+
for i := 0; i < 10000; i++ {
73+
h := header{
74+
fin: randBool(),
75+
rsv1: randBool(),
76+
rsv2: randBool(),
77+
rsv3: randBool(),
78+
opcode: opcode(rand.Intn(16)),
79+
80+
masked: randBool(),
81+
maskKey: rand.Uint32(),
82+
payloadLength: rand.Int63(),
83+
}
84+
85+
testHeader(t, h)
86+
}
87+
})
88+
}
89+
90+
func testHeader(t *testing.T, h header) {
91+
b := &bytes.Buffer{}
92+
w := bufio.NewWriter(b)
93+
r := bufio.NewReader(b)
94+
95+
err := WriteHeader(h, w)
96+
assert.Success(t, err)
97+
err = w.Flush()
98+
assert.Success(t, err)
99+
100+
h2, err := ReadFrameHeader(r)
101+
assert.Success(t, err)
102+
103+
assert.Equalf(t, h, h2, "written and read headers differ")
104+
}
105+
15106
func Test_mask(t *testing.T) {
16107
t.Parallel()
17108

18109
key := []byte{0xa, 0xb, 0xc, 0xff}
19110
key32 := binary.LittleEndian.Uint32(key)
20111
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
21-
gotKey32 := wsframe.Mask(key32, p)
112+
gotKey32 := mask(key32, p)
22113

23114
if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) {
24115
t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p))
@@ -74,7 +165,7 @@ func Benchmark_mask(b *testing.B) {
74165
b.ResetTimer()
75166

76167
for i := 0; i < b.N; i++ {
77-
wsframe.Mask(key32, p)
168+
mask(key32, p)
78169
}
79170
},
80171
},

0 commit comments

Comments
 (0)