Skip to content

Commit 4a5bf3d

Browse files
committed
Rewrite in progress
Too many improvements and changes to list :)
1 parent 992ed07 commit 4a5bf3d

19 files changed

+1134
-3327
lines changed

accept.go

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,15 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
6060
return c, nil
6161
}
6262

63-
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
63+
func (opts *AcceptOptions) ensure() *AcceptOptions {
6464
if opts == nil {
65-
opts = &AcceptOptions{}
65+
return &AcceptOptions{}
6666
}
67+
return nil
68+
}
69+
70+
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
71+
opts = opts.ensure()
6772

6873
err := verifyClientRequest(w, r)
6974
if err != nil {
@@ -126,21 +131,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
126131
return c, nil
127132
}
128133

129-
func authenticateOrigin(r *http.Request) error {
130-
origin := r.Header.Get("Origin")
131-
if origin == "" {
132-
return nil
133-
}
134-
u, err := url.Parse(origin)
135-
if err != nil {
136-
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
137-
}
138-
if !strings.EqualFold(u.Host, r.Host) {
139-
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
140-
}
141-
return nil
142-
}
143-
144134
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
145135
if !r.ProtoAtLeast(1, 1) {
146136
err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
@@ -181,15 +171,37 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
181171
return nil
182172
}
183173

174+
func authenticateOrigin(r *http.Request) error {
175+
origin := r.Header.Get("Origin")
176+
if origin == "" {
177+
return nil
178+
}
179+
u, err := url.Parse(origin)
180+
if err != nil {
181+
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
182+
}
183+
if !strings.EqualFold(u.Host, r.Host) {
184+
return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
185+
}
186+
return nil
187+
}
188+
184189
func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) {
185190
key := r.Header.Get("Sec-WebSocket-Key")
186191
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
187192
}
188193

189194
func selectSubprotocol(r *http.Request, subprotocols []string) string {
195+
cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
196+
if len(cps) == 0 {
197+
return ""
198+
}
199+
190200
for _, sp := range subprotocols {
191-
if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) {
192-
return sp
201+
for _, cp := range cps {
202+
if strings.EqualFold(sp, cp) {
203+
return cp
204+
}
193205
}
194206
}
195207
return ""

autobahn_test.go

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
package websocket_test
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io/ioutil"
8+
"net"
9+
"net/http"
10+
"net/http/httptest"
11+
"nhooyr.io/websocket"
12+
"os"
13+
"os/exec"
14+
"strconv"
15+
"strings"
16+
"testing"
17+
"time"
18+
)
19+
20+
func TestAutobahn(t *testing.T) {
21+
// This test contains the old autobahn test suite tests that use the
22+
// python binary. The approach is clunky and slow so new tests
23+
// have been written in pure Go in websocket_test.go.
24+
// These have been kept for correctness purposes and are occasionally ran.
25+
if os.Getenv("AUTOBAHN") == "" {
26+
t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite")
27+
}
28+
29+
t.Run("server", testServerAutobahnPython)
30+
t.Run("client", testClientAutobahnPython)
31+
}
32+
33+
// https://github.com/crossbario/autobahn-python/tree/master/wstest
34+
func testServerAutobahnPython(t *testing.T) {
35+
t.Parallel()
36+
37+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38+
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
39+
Subprotocols: []string{"echo"},
40+
})
41+
if err != nil {
42+
t.Logf("server handshake failed: %+v", err)
43+
return
44+
}
45+
echoLoop(r.Context(), c)
46+
}))
47+
defer s.Close()
48+
49+
spec := map[string]interface{}{
50+
"outdir": "ci/out/wstestServerReports",
51+
"servers": []interface{}{
52+
map[string]interface{}{
53+
"agent": "main",
54+
"url": strings.Replace(s.URL, "http", "ws", 1),
55+
},
56+
},
57+
"cases": []string{"*"},
58+
// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
59+
// more performance overhead. 7.5.1 is the same.
60+
"exclude-cases": []string{"6.*", "7.5.1"},
61+
}
62+
specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json")
63+
if err != nil {
64+
t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err)
65+
}
66+
defer specFile.Close()
67+
68+
e := json.NewEncoder(specFile)
69+
e.SetIndent("", "\t")
70+
err = e.Encode(spec)
71+
if err != nil {
72+
t.Fatalf("failed to write spec: %v", err)
73+
}
74+
75+
err = specFile.Close()
76+
if err != nil {
77+
t.Fatalf("failed to close file: %v", err)
78+
}
79+
80+
ctx := context.Background()
81+
ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
82+
defer cancel()
83+
84+
args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()}
85+
wstest := exec.CommandContext(ctx, "wstest", args...)
86+
out, err := wstest.CombinedOutput()
87+
if err != nil {
88+
t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out)
89+
}
90+
91+
checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
92+
}
93+
94+
func unusedListenAddr() (string, error) {
95+
l, err := net.Listen("tcp", "localhost:0")
96+
if err != nil {
97+
return "", err
98+
}
99+
l.Close()
100+
return l.Addr().String(), nil
101+
}
102+
103+
// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py
104+
func testClientAutobahnPython(t *testing.T) {
105+
t.Parallel()
106+
107+
if os.Getenv("AUTOBAHN_PYTHON") == "" {
108+
t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite")
109+
}
110+
111+
serverAddr, err := unusedListenAddr()
112+
if err != nil {
113+
t.Fatalf("failed to get unused listen addr for wstest: %v", err)
114+
}
115+
116+
wsServerURL := "ws://" + serverAddr
117+
118+
spec := map[string]interface{}{
119+
"url": wsServerURL,
120+
"outdir": "ci/out/wstestClientReports",
121+
"cases": []string{"*"},
122+
// See TestAutobahnServer for the reasons why we exclude these.
123+
"exclude-cases": []string{"6.*", "7.5.1"},
124+
}
125+
specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json")
126+
if err != nil {
127+
t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err)
128+
}
129+
defer specFile.Close()
130+
131+
e := json.NewEncoder(specFile)
132+
e.SetIndent("", "\t")
133+
err = e.Encode(spec)
134+
if err != nil {
135+
t.Fatalf("failed to write spec: %v", err)
136+
}
137+
138+
err = specFile.Close()
139+
if err != nil {
140+
t.Fatalf("failed to close file: %v", err)
141+
}
142+
143+
ctx := context.Background()
144+
ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
145+
defer cancel()
146+
147+
args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(),
148+
// Disables some server that runs as part of fuzzingserver mode.
149+
// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
150+
"--webport=0",
151+
}
152+
wstest := exec.CommandContext(ctx, "wstest", args...)
153+
err = wstest.Start()
154+
if err != nil {
155+
t.Fatal(err)
156+
}
157+
defer func() {
158+
err := wstest.Process.Kill()
159+
if err != nil {
160+
t.Error(err)
161+
}
162+
}()
163+
164+
// Let it come up.
165+
time.Sleep(time.Second * 5)
166+
167+
var cases int
168+
func() {
169+
c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil)
170+
if err != nil {
171+
t.Fatal(err)
172+
}
173+
defer c.Close(websocket.StatusInternalError, "")
174+
175+
_, r, err := c.Reader(ctx)
176+
if err != nil {
177+
t.Fatal(err)
178+
}
179+
b, err := ioutil.ReadAll(r)
180+
if err != nil {
181+
t.Fatal(err)
182+
}
183+
cases, err = strconv.Atoi(string(b))
184+
if err != nil {
185+
t.Fatal(err)
186+
}
187+
188+
c.Close(websocket.StatusNormalClosure, "")
189+
}()
190+
191+
for i := 1; i <= cases; i++ {
192+
func() {
193+
ctx, cancel := context.WithTimeout(ctx, time.Second*45)
194+
defer cancel()
195+
196+
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil)
197+
if err != nil {
198+
t.Fatal(err)
199+
}
200+
echoLoop(ctx, c)
201+
}()
202+
}
203+
204+
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil)
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
c.Close(websocket.StatusNormalClosure, "")
209+
210+
checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
211+
}
212+
213+
func checkWSTestIndex(t *testing.T, path string) {
214+
wstestOut, err := ioutil.ReadFile(path)
215+
if err != nil {
216+
t.Fatalf("failed to read index.json: %v", err)
217+
}
218+
219+
var indexJSON map[string]map[string]struct {
220+
Behavior string `json:"behavior"`
221+
BehaviorClose string `json:"behaviorClose"`
222+
}
223+
err = json.Unmarshal(wstestOut, &indexJSON)
224+
if err != nil {
225+
t.Fatalf("failed to unmarshal index.json: %v", err)
226+
}
227+
228+
var failed bool
229+
for _, tests := range indexJSON {
230+
for test, result := range tests {
231+
switch result.Behavior {
232+
case "OK", "NON-STRICT", "INFORMATIONAL":
233+
default:
234+
failed = true
235+
t.Errorf("test %v failed", test)
236+
}
237+
switch result.BehaviorClose {
238+
case "OK", "INFORMATIONAL":
239+
default:
240+
failed = true
241+
t.Errorf("bad close behaviour for test %v", test)
242+
}
243+
}
244+
}
245+
246+
if failed {
247+
path = strings.Replace(path, ".json", ".html", 1)
248+
if os.Getenv("CI") == "" {
249+
t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path)
250+
}
251+
}
252+
}

0 commit comments

Comments
 (0)