diff --git a/src/v1/internal/browser/browser-channel.js b/src/v1/internal/browser/browser-channel.js index b81038c48..e72e411ad 100644 --- a/src/v1/internal/browser/browser-channel.js +++ b/src/v1/internal/browser/browser-channel.js @@ -31,7 +31,11 @@ export default class WebSocketChannel { * @param {ChannelConfig} config - configuration for this channel. * @param {function(): string} protocolSupplier - function that detects protocol of the web page. Should only be used in tests. */ - constructor (config, protocolSupplier = detectWebPageProtocol) { + constructor ( + config, + protocolSupplier = detectWebPageProtocol, + socketFactory = createWebSocket + ) { this._open = true this._pending = [] this._error = null @@ -44,23 +48,24 @@ export default class WebSocketChannel { return } - this._ws = createWebSocket(scheme, config.address) + this._ws = socketFactory(scheme, config.address) this._ws.binaryType = 'arraybuffer' - let self = this + const self = this // All connection errors are not sent to the error handler // we must also check for dirty close calls this._ws.onclose = function (e) { if (!e.wasClean) { self._handleConnectionError() } + self._open = false } this._ws.onopen = function () { // Connected! Cancel the connection timeout self._clearConnectionTimeout() // Drain all pending messages - let pending = self._pending + const pending = self._pending self._pending = null for (let i = 0; i < pending.length; i++) { self.write(pending[i]) diff --git a/test/internal/browser/browser-channel.test.js b/test/internal/browser/browser-channel.test.js index dfca43768..46ba9ac74 100644 --- a/test/internal/browser/browser-channel.test.js +++ b/test/internal/browser/browser-channel.test.js @@ -24,6 +24,9 @@ import { setTimeoutMock } from '../timers-util' import { ENCRYPTION_OFF, ENCRYPTION_ON } from '../../../src/v1/internal/util' import ServerAddress from '../../../src/v1/internal/server-address' +const WS_OPEN = 1 +const WS_CLOSED = 3 + /* eslint-disable no-global-assign */ describe('WebSocketChannel', () => { let OriginalWebSocket @@ -236,4 +239,43 @@ describe('WebSocketChannel', () => { expect(channel).toBeDefined() expect(warnMessages.length).toEqual(1) } + + it('should set _open to false when connection closes', async () => { + const fakeSetTimeout = setTimeoutMock.install() + try { + // do not execute setTimeout callbacks + fakeSetTimeout.pause() + const address = ServerAddress.fromUrl('bolt://localhost:8989') + const driverConfig = { connectionTimeout: 4242 } + const channelConfig = new ChannelConfig( + address, + driverConfig, + SERVICE_UNAVAILABLE + ) + webSocketChannel = new WebSocketChannel( + channelConfig, + undefined, + createWebSocketFactory(WS_OPEN) + ) + webSocketChannel._ws.close() + expect(webSocketChannel._open).toBe(false) + } finally { + fakeSetTimeout.uninstall() + } + }) + + function createWebSocketFactory (readyState) { + const ws = {} + ws.readyState = readyState + ws.close = () => { + ws.readyState = WS_CLOSED + if (ws.onclose && typeof ws.onclose === 'function') { + ws.onclose({ wasClean: true }) + } + } + return url => { + ws.url = url + return ws + } + } })