diff --git a/Libraries/WebSocket/RCTSRWebSocket.h b/Libraries/WebSocket/RCTSRWebSocket.h index 1b17cffaf47c08..4784ec92c56a79 100644 --- a/Libraries/WebSocket/RCTSRWebSocket.h +++ b/Libraries/WebSocket/RCTSRWebSocket.h @@ -80,6 +80,8 @@ extern NSString *const RCTSRHTTPResponseErrorKey; - (void)open; - (void)close; +- (void)closeSync; + - (void)closeWithCode:(NSInteger)code reason:(NSString *)reason; // Send a UTF8 String or Data. diff --git a/Libraries/WebSocket/RCTSRWebSocket.m b/Libraries/WebSocket/RCTSRWebSocket.m index 8ce6edc4dae216..a843ca6e4e7c76 100644 --- a/Libraries/WebSocket/RCTSRWebSocket.m +++ b/Libraries/WebSocket/RCTSRWebSocket.m @@ -554,48 +554,65 @@ - (void)close [self closeWithCode:RCTSRStatusCodeNormal reason:nil]; } +- (void)closeSync +{ + [self closeWithCode:RCTSRStatusCodeNormal reason:nil isBlocking:YES]; +} + - (void)closeWithCode:(NSInteger)code reason:(NSString *)reason { - assert(code); - dispatch_async(_workQueue, ^{ - if (self.readyState == RCTSR_CLOSING || self.readyState == RCTSR_CLOSED) { - return; - } + [self closeWithCode:code reason:reason isBlocking:NO]; +} - BOOL wasConnecting = self.readyState == RCTSR_CONNECTING; +- (void)closeWithCode:(NSInteger)code reason:(NSString *)reason isBlocking:(BOOL)isBlocking +{ + assert(code); + + void (^performClose)(void) = ^{ + if (self.readyState == RCTSR_CLOSING || self.readyState == RCTSR_CLOSED) { + return; + } - self.readyState = RCTSR_CLOSING; + BOOL wasConnecting = self.readyState == RCTSR_CONNECTING; - RCTSRLog(@"Closing with code %ld reason %@", code, reason); + self.readyState = RCTSR_CLOSING; - if (wasConnecting) { - [self _disconnect]; - return; - } + RCTSRLog(@"Closing with code %ld reason %@", code, reason); + + if (wasConnecting) { + [self _disconnect:isBlocking]; + return; + } - size_t maxMsgSize = [reason maximumLengthOfBytesUsingEncoding:NSUTF8StringEncoding]; - NSMutableData *mutablePayload = [[NSMutableData alloc] initWithLength:sizeof(uint16_t) + maxMsgSize]; - NSData *payload = mutablePayload; + size_t maxMsgSize = [reason maximumLengthOfBytesUsingEncoding:NSUTF8StringEncoding]; + NSMutableData *mutablePayload = [[NSMutableData alloc] initWithLength:sizeof(uint16_t) + maxMsgSize]; + NSData *payload = mutablePayload; - ((uint16_t *)mutablePayload.mutableBytes)[0] = NSSwapBigShortToHost(code); + ((uint16_t *)mutablePayload.mutableBytes)[0] = NSSwapBigShortToHost(code); - if (reason) { - NSRange remainingRange = {0}; + if (reason) { + NSRange remainingRange = {0}; - NSUInteger usedLength = 0; + NSUInteger usedLength = 0; - BOOL success __unused = [reason getBytes:(char *)mutablePayload.mutableBytes + sizeof(uint16_t) maxLength:payload.length - sizeof(uint16_t) usedLength:&usedLength encoding:NSUTF8StringEncoding options:NSStringEncodingConversionExternalRepresentation range:NSMakeRange(0, reason.length) remainingRange:&remainingRange]; + BOOL success __unused = [reason getBytes:(char *)mutablePayload.mutableBytes + sizeof(uint16_t) maxLength:payload.length - sizeof(uint16_t) usedLength:&usedLength encoding:NSUTF8StringEncoding options:NSStringEncodingConversionExternalRepresentation range:NSMakeRange(0, reason.length) remainingRange:&remainingRange]; - assert(success); - assert(remainingRange.length == 0); + assert(success); + assert(remainingRange.length == 0); - if (usedLength != maxMsgSize) { - payload = [payload subdataWithRange:NSMakeRange(0, usedLength + sizeof(uint16_t))]; + if (usedLength != maxMsgSize) { + payload = [payload subdataWithRange:NSMakeRange(0, usedLength + sizeof(uint16_t))]; + } } - } - [self _sendFrameWithOpcode:RCTSROpCodeConnectionClose data:payload]; - }); + [self _sendFrameWithOpcode:RCTSROpCodeConnectionClose data:payload skipWorkQueueAssertion:isBlocking]; + }; + + if (isBlocking) { + performClose(); + } else { + dispatch_async(_workQueue, performClose); + } } - (void)_closeWithProtocolError:(NSString *)message @@ -630,15 +647,22 @@ - (void)_failWithError:(NSError *)error }); } -- (void)_writeData:(NSData *)data +- (void)_writeData:(NSData *)data skipWorkQueueAssertion:(BOOL)skipWorkQueueAssertion { - [self assertOnWorkQueue]; + if (skipWorkQueueAssertion == NO){ + [self assertOnWorkQueue]; + } if (_closeWhenFinishedWriting) { return; } [_outputBuffer appendData:data]; - [self _pumpWriting]; + [self _pumpWriting:skipWorkQueueAssertion]; +} + +- (void)_writeData:(NSData *)data +{ + [self _writeData:data skipWorkQueueAssertion:NO]; } - (void)send:(id)data @@ -772,14 +796,22 @@ - (void)handleCloseWithData:(NSData *)data }); } -- (void)_disconnect +- (void)_disconnect:(BOOL)skipWorkQueueAssertion { - [self assertOnWorkQueue]; + if (skipWorkQueueAssertion == NO) { + [self assertOnWorkQueue]; + } + RCTSRLog(@"Trying to disconnect"); _closeWhenFinishedWriting = YES; [self _pumpWriting]; } +- (void)_disconnect +{ + [self _disconnect:NO]; +} + - (void)_handleFrameWithData:(NSData *)frameData opCode:(NSInteger)opcode { // Check that the current data is valid UTF8 @@ -1005,9 +1037,14 @@ - (void)_readFrameNew }); } -- (void)_pumpWriting +- (void)_pumpWriting { + [self _pumpWriting:NO]; +} +- (void)_pumpWriting:(BOOL)skipWorkQueueAssertion { - [self assertOnWorkQueue]; + if (skipWorkQueueAssertion == NO) { + [self assertOnWorkQueue]; + } NSUInteger dataLength = _outputBuffer.length; if (dataLength - _outputBufferOffset > 0 && _outputStream.hasSpaceAvailable) { @@ -1223,9 +1260,16 @@ - (void)_pumpScanner static const size_t RCTSRFrameHeaderOverhead = 32; -- (void)_sendFrameWithOpcode:(RCTSROpCode)opcode data:(NSData *)data +- (void)_sendFrameWithOpcode:(RCTSROpCode)opcode data:(NSData *)data { + [self _sendFrameWithOpcode:opcode data:data skipWorkQueueAssertion:NO]; +} + + +- (void)_sendFrameWithOpcode:(RCTSROpCode)opcode data:(NSData *)data skipWorkQueueAssertion:(BOOL)skipWorkQueueAssertion { - [self assertOnWorkQueue]; + if (skipWorkQueueAssertion == NO) { + [self assertOnWorkQueue]; + } if (nil == data) { return; @@ -1290,7 +1334,7 @@ - (void)_sendFrameWithOpcode:(RCTSROpCode)opcode data:(NSData *)data assert(frame_buffer_size <= [frame length]); frame.length = frame_buffer_size; - [self _writeData:frame]; + [self _writeData:frame skipWorkQueueAssertion:skipWorkQueueAssertion]; } - (void)stream:(NSStream *)aStream handleEvent:(NSStreamEvent)eventCode diff --git a/React/CoreModules/RCTWebSocketModule.h b/React/CoreModules/RCTWebSocketModule.h index c1a48d3d31b9a3..b658c7197cc018 100644 --- a/React/CoreModules/RCTWebSocketModule.h +++ b/React/CoreModules/RCTWebSocketModule.h @@ -24,6 +24,9 @@ NS_ASSUME_NONNULL_BEGIN - (void)sendData:(NSData *)data forSocketID:(nonnull NSNumber *)socketID; +// Closes all open websockets on the main thread +- (void)flush; + @end @interface RCTBridge (RCTWebSocketModule) diff --git a/React/CoreModules/RCTWebSocketModule.mm b/React/CoreModules/RCTWebSocketModule.mm index 32889ea7d4bddb..48fe67e01e0bef 100644 --- a/React/CoreModules/RCTWebSocketModule.mm +++ b/React/CoreModules/RCTWebSocketModule.mm @@ -51,6 +51,16 @@ - (NSArray *)supportedEvents return @[ @"websocketMessage", @"websocketOpen", @"websocketFailed", @"websocketClosed" ]; } + +- (void)flush +{ + _contentHandlers = nil; + for (RCTSRWebSocket *socket in _sockets.allValues) { + socket.delegate = nil; + [socket closeSync]; + } +} + - (void)invalidate { [super invalidate];