Skip to content

Commit 7e2bb22

Browse files
committed
feat(websocket): add websocket origin limit
1 parent aaafef9 commit 7e2bb22

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

bin/proxy/webSocketServer.js

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
const WebSocket = require('ws');
2+
const http = require('http');
3+
const url = require('url');
4+
const config = require('./config.js');
5+
6+
class webSocketServer extends WebSocket.Server {
7+
handleUpgrade(req, socket, head, cb) {
8+
const origin = req['headers']['origin'] || '';
9+
if (origin) {
10+
const obj = url.parse(origin);
11+
const host = req['headers']['host'];
12+
13+
if (obj.host !== host) {
14+
const allowWebSocketOriginHost = config.allowWebSocketOriginHost || [];
15+
if (allowWebSocketOriginHost.length === 0) {
16+
abortConnection(socket, 403);
17+
return;
18+
}
19+
20+
let i,
21+
len,
22+
v;
23+
for (i = 0, len = allowWebSocketOriginHost.length; i < len; i++) {
24+
v = allowWebSocketOriginHost[i];
25+
26+
if (typeof v === 'string') {
27+
if (v !== obj.host) {
28+
abortConnection(socket, 403);
29+
return;
30+
}
31+
} else if (typeof v === 'object') {
32+
if (!v.test || (v.test && !v.test(host))) {
33+
abortConnection(socket, 403);
34+
return;
35+
}
36+
}
37+
}
38+
}
39+
}
40+
41+
super.handleUpgrade(req, socket, head, cb);
42+
}
43+
}
44+
45+
module.exports = webSocketServer;
46+
47+
function abortConnection(socket, code, message) {
48+
if (socket.writable) {
49+
message = message || http.STATUS_CODES[code];
50+
socket.write(
51+
`HTTP/1.1 ${code} ${http.STATUS_CODES[code]}\r\n` +
52+
'Connection: close\r\n' +
53+
'Content-type: text/html\r\n' +
54+
`Content-Length: ${Buffer.byteLength(message)}\r\n` +
55+
'\r\n' + message
56+
);
57+
}
58+
59+
socket.destroy();
60+
}

bin/proxy/websocket.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
const WebSocket = require('ws');
12-
const WSServer = WebSocket.Server;
12+
const WSServer = require('./webSocketServer');
1313
const logger = require('logger');
1414
const domain = require('domain');
1515
const Context = require('runtime/Context');
@@ -46,7 +46,7 @@ function wsFiller(ws, req) {
4646
}
4747
};
4848

49-
ws.logKey = Math.random();
49+
ws.logKey = req.headers['sec-websocket-key'] || Math.random();
5050
}
5151

5252
function reportWebSocketLog(ws, isEnd) {

0 commit comments

Comments
 (0)