diff --git a/Package.swift b/Package.swift index eaf1e1b..cadb0ac 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.5 +// swift-tools-version:5.7 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription diff --git a/README.md b/README.md index 7220b99..f21455b 100644 --- a/README.md +++ b/README.md @@ -96,8 +96,67 @@ allUserId99Changes.subscribe() allUserId99Changes.unsubscribe() allUserId99Changes.off(.all) ``` +### Broadcast +* Listen for `broadcast` messages: +```swift +let channel = client.channel(.table("channel_id", schema: "someChannel"), options: .init(presenceKey: "user_uuid")) +channel.on(.broadcast) { message in + let payload = message.payload["payload"] + let event = message.payload["event"] + let type = message.payload["type"] + print(type, event, payload) +} + +channel.join() +``` + +* Send `broadcast` messages: + +```swift +let channel = client.channel(.table("channel_id", schema: "someChannel"), options: .init(presenceKey: "user_uuid")) +channel.join() + +channel.broadcast(event: "my_event", payload: ["hello": "world"]) +``` +### Presence + +Presence can be used to share state between clients. + +* Listen to presence `sync` events to track state changes: + +```swift +let channel = client.channel(.table("channel_id", schema: "someChannel"), options: .init(presenceKey: "user_uuid")) +let presence = Presence(channel: channel) + +presence.onSync { + print("presence sync", presence?.state, presence?.list()) +} + +channel.join() +// ... +``` + +* Track presence state changes: + +```swift +let channel = client.channel(.table("channel_id", schema: "someChannel"), options: .init(presenceKey: "user_uuid")) +channel.join() + +channel.track(payload: [ + ["hello": "world] +]) +``` + +* Remove tracked presence state changes: + +```swift +let channel = client.channel(.table("channel_id", schema: "someChannel"), options: .init(presenceKey: "user_uuid")) +channel.join() + +channel.untrack() +``` ## Credits - https://github.com/supabase/realtime-js diff --git a/Sources/Realtime/Channel.swift b/Sources/Realtime/Channel.swift index d7b56e1..c22a521 100644 --- a/Sources/Realtime/Channel.swift +++ b/Sources/Realtime/Channel.swift @@ -55,11 +55,11 @@ struct Binding { /// public class Channel { - /// The topic of the Channel. e.g. `.table("rooms", "friends")` + /// The topic of the Channel. e.g. "rooms:friends" public let topic: ChannelTopic /// The params sent when joining the channel - public var params: [String: Any] { + public var params: Payload { didSet { self.joinPush.payload = params } } @@ -70,7 +70,7 @@ public class Channel { var state: ChannelState /// Collection of event bindings - var bindingsDel: [Binding] + var syncBindingsDel: SynchronizedArray /// Tracks event binding ref counters var bindingRef: Int @@ -93,17 +93,25 @@ public class Channel { /// Refs of stateChange hooks var stateChangeRefs: [String] + /// Initialize a Channel + /// - parameter topic: Topic of the Channel + /// - parameter options: Optional. Options to configure channel broadcast and presence. Leave nil for postgres channel. + /// - parameter socket: Socket that the channel is a part of + convenience init(topic: ChannelTopic, options: ChannelOptions? = nil, socket: RealtimeClient) { + self.init(topic: topic, params: options?.params ?? [:], socket: socket) + } + /// Initialize a Channel /// /// - parameter topic: Topic of the Channel /// - parameter params: Optional. Parameters to send when joining. /// - parameter socket: Socket that the channel is a part of - init(topic: ChannelTopic, params: [String: Any] = [:], socket: RealtimeClient) { + init(topic: ChannelTopic, params: [String: Any], socket: RealtimeClient) { state = ChannelState.closed self.topic = topic self.params = params self.socket = socket - bindingsDel = [] + syncBindingsDel = SynchronizedArray() bindingRef = 0 timeout = socket.timeout joinedOnce = false @@ -127,7 +135,8 @@ public class Channel { to: self, callback: { (self, _) in self.rejoinTimer.reset() - }) + } + ) if let ref = onErrorRef { stateChangeRefs.append(ref) } let onOpenRef = self.socket?.delegateOnOpen( @@ -135,7 +144,8 @@ public class Channel { callback: { (self) in self.rejoinTimer.reset() if self.isErrored { self.rejoin() } - }) + } + ) if let ref = onOpenRef { stateChangeRefs.append(ref) } // Setup Push Event to be sent when joining @@ -143,10 +153,11 @@ public class Channel { channel: self, event: ChannelEvent.join, payload: self.params, - timeout: timeout) + timeout: timeout + ) /// Handle when a response is received after join() - joinPush.delegateReceive("ok", to: self) { (self, _) in + joinPush.delegateReceive(.ok, to: self) { (self, _) in // Mark the Channel as joined self.state = ChannelState.joined @@ -159,22 +170,24 @@ public class Channel { } // Perform if Channel errors while attempting to joi - joinPush.delegateReceive("error", to: self) { (self, _) in + joinPush.delegateReceive(.error, to: self) { (self, _) in self.state = .errored if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } } // Handle when the join push times out when sending after join() - joinPush.delegateReceive("timeout", to: self) { (self, _) in + joinPush.delegateReceive(.timeout, to: self) { (self, _) in // log that the channel timed out self.socket?.logItems( - "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s") + "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s" + ) // Send a Push to the server to leave the channel let leavePush = Push( channel: self, event: ChannelEvent.leave, - timeout: self.timeout) + timeout: self.timeout + ) leavePush.send() // Mark the Channel as in an error and attempt to rejoin if socket is connected @@ -191,7 +204,8 @@ public class Channel { // Log that the channel was left self.socket?.logItems( - "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")") + "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" + ) // Mark the channel as closed and remove it from the socket self.state = ChannelState.closed @@ -227,9 +241,10 @@ public class Channel { // Trigger bindings self.trigger( event: ChannelEvent.channelReply(message.ref), - payload: message.payload, + payload: message.rawPayload, ref: message.ref, - joinRef: message.joinRef) + joinRef: message.joinRef + ) } } @@ -251,7 +266,7 @@ public class Channel { /// - parameter timeout: Optional. Defaults to Channel's timeout /// - return: Push event @discardableResult - public func subscribe(timeout: TimeInterval? = nil) -> Push { + public func join(timeout: TimeInterval? = nil) -> Push { guard !joinedOnce else { fatalError( "tried to join multiple times. 'join' " @@ -352,13 +367,13 @@ public class Channel { /// Example: /// /// let channel = socket.channel("topic") - /// let ref1 = channel.on(.all) { [weak self] (message) in + /// let ref1 = channel.on("event") { [weak self] (message) in /// self?.print("do stuff") /// } - /// let ref2 = channel.on(.all) { [weak self] (message) in + /// let ref2 = channel.on("event") { [weak self] (message) in /// self?.print("do other stuff") /// } - /// channel.off(.all, ref1) + /// channel.off("event", ref1) /// /// Since unsubscription of ref1, "do stuff" won't print, but "do other /// stuff" will keep on printing on the "event" @@ -383,16 +398,16 @@ public class Channel { /// Example: /// /// let channel = socket.channel("topic") - /// let ref1 = channel.delegateOn(.all, to: self) { (self, message) in + /// let ref1 = channel.delegateOn("event", to: self) { (self, message) in /// self?.print("do stuff") /// } - /// let ref2 = channel.delegateOn(.all, to: self) { (self, message) in + /// let ref2 = channel.delegateOn("event", to: self) { (self, message) in /// self?.print("do other stuff") /// } - /// channel.off(.all, ref1) + /// channel.off("event", ref1) /// /// Since unsubscription of ref1, "do stuff" won't print, but "do other - /// stuff" will keep on printing on all "event" (*). + /// stuff" will keep on printing on the "event" /// /// - parameter event: Event to receive /// - parameter owner: Class registering the callback. Usually `self` @@ -416,7 +431,7 @@ public class Channel { let ref = bindingRef bindingRef = ref + 1 - bindingsDel.append(Binding(event: event, ref: ref, callback: delegated)) + syncBindingsDel.append(Binding(event: event, ref: ref, callback: delegated)) return ref } @@ -427,20 +442,20 @@ public class Channel { /// Example: /// /// let channel = socket.channel("topic") - /// let ref1 = channel.on(.insert) { _ in print("ref1 event" } - /// let ref2 = channel.on(.insert) { _ in print("ref2 event" } - /// let ref3 = channel.on(.update) { _ in print("ref3 other" } - /// let ref4 = channel.on(.update) { _ in print("ref4 other" } - /// channel.off(.insert, ref1) - /// channel.off(.update) + /// let ref1 = channel.on("event") { _ in print("ref1 event" } + /// let ref2 = channel.on("event") { _ in print("ref2 event" } + /// let ref3 = channel.on("other_event") { _ in print("ref3 other" } + /// let ref4 = channel.on("other_event") { _ in print("ref4 other" } + /// channel.off("event", ref1) + /// channel.off("other_event") /// /// After this, only "ref2 event" will be printed if the channel receives - /// "insert" and nothing is printed if the channel receives "update". + /// "event" and nothing is printed if the channel receives "other_event". /// /// - parameter event: Event to unsubscribe from /// - paramter ref: Ref counter returned when subscribing. Can be omitted public func off(_ event: ChannelEvent, ref: Int? = nil) { - bindingsDel.removeAll { (bind) -> Bool in + syncBindingsDel.removeAll { bind -> Bool in bind.event == event && (ref == nil || ref == bind.ref) } } @@ -450,7 +465,7 @@ public class Channel { /// Example: /// /// channel - /// .push(.update, payload: ["message": "hello") + /// .push("event", payload: ["message": "hello") /// .receive("ok") { _ in { print("message sent") } /// /// - parameter event: Event to push @@ -459,7 +474,7 @@ public class Channel { @discardableResult public func push( _ event: ChannelEvent, - payload: [String: Any], + payload: Payload, timeout: TimeInterval = Defaults.timeoutInterval ) -> Push { guard joinedOnce else { @@ -472,7 +487,8 @@ public class Channel { channel: self, event: event, payload: payload, - timeout: timeout) + timeout: timeout + ) if canPush { pushEvent.send() } else { @@ -495,12 +511,12 @@ public class Channel { /// /// Example: //// - /// channel.unsubscribe().receive("ok") { _ in { print("left") } + /// channel.leave().receive("ok") { _ in { print("left") } /// /// - parameter timeout: Optional timeout /// - return: Push that can add receive hooks @discardableResult - public func unsubscribe(timeout: TimeInterval = Defaults.timeoutInterval) -> Push { + public func leave(timeout: TimeInterval = Defaults.timeoutInterval) -> Push { // If attempting a rejoin during a leave, then reset, cancelling the rejoin rejoinTimer.reset() @@ -520,17 +536,18 @@ public class Channel { let leavePush = Push( channel: self, event: ChannelEvent.leave, - timeout: timeout) + timeout: timeout + ) // Perform the same behavior if successfully left the channel // or if sending the event timed out leavePush - .receive("ok", delegated: onCloseDelegate) - .receive("timeout", delegated: onCloseDelegate) + .receive(.ok, delegated: onCloseDelegate) + .receive(.timeout, delegated: onCloseDelegate) leavePush.send() // If the Channel cannot send push events, trigger a success locally - if !canPush { leavePush.trigger("ok", payload: [:]) } + if !canPush { leavePush.trigger(.ok, payload: [:]) } // Return the push so it can be bound to return leavePush @@ -560,12 +577,13 @@ public class Channel { guard let safeJoinRef = message.joinRef, safeJoinRef != joinRef, - ChannelEvent.isLifecyleEvent(message.event) + message.event.isLifecyleEvent else { return true } socket?.logItems( - "channel", "dropping outdated message", message.topic, message.event, message.payload, - safeJoinRef) + "channel", "dropping outdated message", message.topic, message.event, message.rawPayload, + safeJoinRef + ) return false } @@ -594,13 +612,13 @@ public class Channel { func trigger(_ message: Message) { let handledMessage = onMessage(message) - bindingsDel + syncBindingsDel .filter { $0.event == message.event } .forEach { $0.callback.call(handledMessage) } } /// Triggers an event to the correct event bindings created by - //// `channel.on(event)`. + //// `channel.on("event")`. /// /// - parameter event: Event to trigger /// - parameter payload: Payload of the event @@ -608,7 +626,7 @@ public class Channel { /// - parameter joinRef: Ref of the join event. Defaults to nil func trigger( event: ChannelEvent, - payload: [String: Any] = [:], + payload: Payload = [:], ref: String = "", joinRef: String? = nil ) { @@ -617,7 +635,8 @@ public class Channel { topic: topic, event: event, payload: payload, - joinRef: joinRef ?? self.joinRef) + joinRef: joinRef ?? self.joinRef + ) trigger(message) } @@ -664,3 +683,140 @@ extension Channel { return state == .leaving } } +// ---------------------------------------------------------------------- + +// MARK: - Codable Payload + +// ---------------------------------------------------------------------- + +extension Payload { + + /// Initializes a payload from a given value + /// - parameter value: The value to encode + /// - parameter encoder: The encoder to use to encode the payload + /// - throws: Throws an error if the payload cannot be encoded + init(_ value: T, encoder: JSONEncoder = Defaults.encoder) throws { + let data = try encoder.encode(value) + self = try JSONSerialization.jsonObject(with: data, options: .allowFragments) as! Payload + } + + /// Decodes the payload to a given type + /// - parameter type: The type to decode to + /// - parameter decoder: The decoder to use to decode the payload + /// - returns: The decoded payload + /// - throws: Throws an error if the payload cannot be decoded + public func decode( + to type: T.Type = T.self, decoder: JSONDecoder = Defaults.decoder + ) throws -> T { + let data = try JSONSerialization.data(withJSONObject: self) + return try decoder.decode(type, from: data) + } + +} + +// ---------------------------------------------------------------------- + +// MARK: - Broadcast API + +// ---------------------------------------------------------------------- + +/// Represents the payload of a broadcast message +public struct BroadcastPayload { + public let type: String + public let event: String + public let payload: Payload +} + +extension Channel { + /// Broadcasts the payload to all other members of the channel + /// - parameter event: The event to broadcast + /// - parameter payload: The payload to broadcast + @discardableResult + public func broadcast(event: String, payload: Payload) -> Push { + self.push( + .broadcast, + payload: [ + "type": "broadcast", + "event": event, + "payload": payload, + ]) + } + + /// Broadcasts the encodable payload to all other members of the channel + /// - parameter event: The event to broadcast + /// - parameter payload: The payload to broadcast + /// - parameter encoder: The encoder to use to encode the payload + /// - throws: Throws an error if the payload cannot be encoded + @discardableResult + public func broadcast(event: String, payload: Encodable, encoder: JSONEncoder = Defaults.encoder) + throws -> Push + { + self.broadcast(event: event, payload: try Payload(payload)) + } + + /// Subscribes to broadcast events. Does not handle retain cycles. + /// + /// Example: + /// + /// let ref = channel.onBroadcast { [weak self] (message,broadcast) in + /// print(broadcast.event, broadcast.payload) + /// } + /// channel.off(.broadcast, ref1) + /// + /// Subscription returns a ref counter, which can be used later to + /// unsubscribe the exact event listener + /// - parameter callback: Called with the broadcast payload + /// - returns: Ref counter of the subscription. See `func off()` + @discardableResult + public func onBroadcast(callback: @escaping (Message, BroadcastPayload) -> Void) -> Int { + self.on( + .broadcast, + callback: { message in + let payload = BroadcastPayload( + type: message.payload["type"] as! String, event: message.payload["event"] as! String, + payload: message.payload["payload"] as! Payload) + callback(message, payload) + }) + } + +} +// ---------------------------------------------------------------------- + +// MARK: - Presence API + +// ---------------------------------------------------------------------- + +extension Channel { + /// Share presence state, available to all channel members via sync + /// - parameter payload: The payload to broadcast + @discardableResult + public func track(payload: Payload) -> Push { + self.push( + .presence, + payload: [ + "type": "presence", + "event": "track", + "payload": payload, + ]) + } + + /// Share presence state, available to all channel members via sync + /// - parameter payload: The payload to broadcast + /// - parameter encoder: The encoder to use to encode the payload + /// - throws: Throws an error if the payload cannot be encoded + @discardableResult + public func track(payload: Encodable, encoder: JSONEncoder = Defaults.encoder) throws -> Push { + self.track(payload: try Payload(payload)) + } + + /// Remove presence state for given channel + @discardableResult + public func untrack() -> Push { + self.push( + .presence, + payload: [ + "type": "presence", + "event": "untrack", + ]) + } +} diff --git a/Sources/Realtime/Defaults.swift b/Sources/Realtime/Defaults.swift index 4383785..bfc4e55 100644 --- a/Sources/Realtime/Defaults.swift +++ b/Sources/Realtime/Defaults.swift @@ -20,7 +20,7 @@ import Foundation -/// A collection of default values and behaviors used accross the Client +/// A collection of default values and behaviors used across the Client public enum Defaults { /// Default timeout when sending messages public static let timeoutInterval: TimeInterval = 10.0 @@ -28,6 +28,9 @@ public enum Defaults { /// Default interval to send heartbeats on public static let heartbeatInterval: TimeInterval = 30.0 + /// Default maximum amount of time which the system may delay heartbeat events in order to minimize power usage + public static let heartbeatLeeway: DispatchTimeInterval = .milliseconds(10) + /// Default reconnect algorithm for the socket public static let reconnectSteppedBackOff: (Int) -> TimeInterval = { tries in tries > 9 ? 5.0 : [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0, 2.0][tries - 1] @@ -38,26 +41,40 @@ public enum Defaults { tries > 3 ? 10 : [1, 2, 5][tries - 1] } + public static let vsn = "2.0.0" + + /// Default encoder + public static let encoder: JSONEncoder = JSONEncoder() + /// Default encode function, utilizing JSONSerialization.data - public static let encode: ([String: Any]) -> Data = { json in - try! JSONSerialization + public static let encode: (Any) -> Data = { json in + assert(JSONSerialization.isValidJSONObject(json), "Invalid JSON object") + return + try! JSONSerialization .data( withJSONObject: json, - options: JSONSerialization.WritingOptions()) + options: JSONSerialization.WritingOptions() + ) } + /// Default decoder + public static let decoder: JSONDecoder = JSONDecoder() + /// Default decode function, utilizing JSONSerialization.jsonObject - public static let decode: (Data) -> [String: Any]? = { data in + public static let decode: (Data) -> Any? = { data in guard let json = try? JSONSerialization .jsonObject( with: data, - options: JSONSerialization.ReadingOptions()) - as? [String: Any] + options: JSONSerialization.ReadingOptions() + ) else { return nil } return json } + + public static let heartbeatQueue: DispatchQueue = .init( + label: "com.phoenix.socket.heartbeat") } /// Represents the multiple states that a Channel can be in @@ -74,6 +91,11 @@ public enum ChannelState: String { /// a channel regarding a Channel's lifecycle or /// that can be registered to be notified of. public enum ChannelEvent: RawRepresentable { + public enum Presence: String { + case state + case diff + } + case heartbeat case join case leave @@ -88,6 +110,12 @@ public enum ChannelEvent: RawRepresentable { case channelReply(String) + case broadcast + + case presence + case presenceState + case presenceDiff + public var rawValue: String { switch self { case .heartbeat: return "heartbeat" @@ -102,7 +130,13 @@ public enum ChannelEvent: RawRepresentable { case .update: return "update" case .delete: return "delete" - case .channelReply(let reference): return "chan_reply_\(reference)" + case let .channelReply(reference): return "chan_reply_\(reference)" + + case .broadcast: return "broadcast" + + case .presence: return "presence" + case .presenceState: return "presence_state" + case .presenceDiff: return "presence_diff" } } @@ -118,14 +152,18 @@ public enum ChannelEvent: RawRepresentable { case "insert": self = .insert case "update": self = .update case "delete": self = .delete + case "broadcast": self = .broadcast + case "presence": self = .presence + case "presence_state": self = .presenceState + case "presence_diff": self = .presenceDiff default: return nil } } - static func isLifecyleEvent(_ event: ChannelEvent) -> Bool { - switch event { + var isLifecyleEvent: Bool { + switch self { case .join, .leave, .reply, .error, .close: return true - case .heartbeat, .all, .insert, .update, .delete, .channelReply: return false + default: return false } } } @@ -142,9 +180,9 @@ public enum ChannelTopic: RawRepresentable, Equatable { public var rawValue: String { switch self { case .all: return "realtime:*" - case .schema(let name): return "realtime:\(name)" - case .table(let tableName, let schema): return "realtime:\(schema):\(tableName)" - case .column(let columnName, let value, let table, let schema): + case let .schema(name): return "realtime:\(name)" + case let .table(tableName, schema): return "realtime:\(schema):\(tableName)" + case let .column(columnName, value, table, schema): return "realtime:\(schema):\(table):\(columnName)=eq.\(value)" case .heartbeat: return "phoenix" } @@ -169,7 +207,8 @@ public enum ChannelTopic: RawRepresentable, Equatable { { self = .column( String(condition[0]), value: String(condition[1].dropFirst(3)), table: String(parts[1]), - schema: String(parts[0])) + schema: String(parts[0]) + ) } else { return nil } @@ -179,3 +218,44 @@ public enum ChannelTopic: RawRepresentable, Equatable { } } } + +/// Represents the broadcast and presence options for a channel. +public struct ChannelOptions { + /// Used to track presence payload across clients. Must be unique per client. If `nil`, the server will generate one. + var presenceKey: String? + /// Enables the client to receieve their own`broadcast` messages + var broadcastSelf: Bool + /// Instructs the server to acknoledge the client's `broadcast` messages + var broadcastAcknowledge: Bool + + public init( + presenceKey: String? = nil, broadcastSelf: Bool = false, broadcastAcknowledge: Bool = false + ) { + self.presenceKey = presenceKey + self.broadcastSelf = broadcastSelf + self.broadcastAcknowledge = broadcastAcknowledge + } + + /// Parameters used to configure the channel + var params: [String: [String: Any]] { + [ + "config": [ + "presence": [ + "key": presenceKey ?? "" + ], + "broadcast": [ + "ack": broadcastAcknowledge, + "self": broadcastSelf, + ], + ] + ] + } + +} + +/// Represents the different status of a push +public enum PushStatus: String { + case ok + case error + case timeout +} diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift index 4983f32..d8de6c5 100644 --- a/Sources/Realtime/HeartbeatTimer.swift +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -20,80 +20,114 @@ import Foundation -class HeartbeatTimer: Equatable { +/// Heartbeat Timer class which manages the lifecycle of the underlying +/// timer which triggers when a heartbeat should be fired. This heartbeat +/// runs on it's own Queue so that it does not interfere with the main +/// queue but guarantees thread safety. +class HeartbeatTimer { + // ---------------------------------------------------------------------- + + // MARK: - Dependencies + + // ---------------------------------------------------------------------- + // The interval to wait before firing the Timer let timeInterval: TimeInterval - let dispatchQueue: DispatchQueue - let id: String = UUID().uuidString - init(timeInterval: TimeInterval, dispatchQueue: DispatchQueue) { - self.timeInterval = timeInterval - self.dispatchQueue = dispatchQueue - } - private lazy var timer: DispatchSourceTimer = { - let t = DispatchSource.makeTimerSource(flags: [], queue: self.dispatchQueue) - t.schedule(deadline: .now() + self.timeInterval, repeating: self.timeInterval) - t.setEventHandler(handler: { [weak self] in - self?.eventHandler?() - }) - return t - }() + /// The maximum amount of time which the system may delay the delivery of the timer events + let leeway: DispatchTimeInterval - var isValid: Bool { - return state == .resumed - } + // The DispatchQueue to schedule the timers on + let queue: DispatchQueue - private var eventHandler: (() -> Void)? + // UUID which specifies the Timer instance. Verifies that timers are different + let uuid: String = UUID().uuidString - private enum State { - case suspended - case resumed - } + // ---------------------------------------------------------------------- + + // MARK: - Properties + + // ---------------------------------------------------------------------- + // The underlying, cancelable, resettable, timer. + private var temporaryTimer: DispatchSourceTimer? + // The event handler that is called by the timer when it fires. + private var temporaryEventHandler: (() -> Void)? - private var state: State = .suspended + /** + Create a new HeartbeatTimer - func startTimerWithEvent(eventHandler: (() -> Void)?) { - self.eventHandler = eventHandler - resume() + - Parameters: + - timeInterval: Interval to fire the timer. Repeats + - queue: Queue to schedule the timer on + - leeway: The maximum amount of time which the system may delay the delivery of the timer events + */ + init( + timeInterval: TimeInterval, queue: DispatchQueue = Defaults.heartbeatQueue, + leeway: DispatchTimeInterval = Defaults.heartbeatLeeway + ) { + self.timeInterval = timeInterval + self.queue = queue + self.leeway = leeway } - func stopTimer() { - timer.setEventHandler {} - eventHandler = nil - suspend() + /** + Create a new HeartbeatTimer + + - Parameter timeInterval: Interval to fire the timer. Repeats + */ + convenience init(timeInterval: TimeInterval) { + self.init(timeInterval: timeInterval, queue: Defaults.heartbeatQueue) } - private func resume() { - if state == .resumed { - return + func start(eventHandler: @escaping () -> Void) { + queue.sync { + // Create a new DispatchSourceTimer, passing the event handler + let timer = DispatchSource.makeTimerSource(flags: [], queue: queue) + timer.setEventHandler(handler: eventHandler) + + // Schedule the timer to first fire in `timeInterval` and then + // repeat every `timeInterval` + timer.schedule( + deadline: DispatchTime.now() + self.timeInterval, + repeating: self.timeInterval, + leeway: self.leeway + ) + + // Start the timer + timer.resume() + self.temporaryEventHandler = eventHandler + self.temporaryTimer = timer } - state = .resumed - timer.resume() } - private func suspend() { - if state == .suspended { - return + func stop() { + // Must be queued synchronously to prevent threading issues. + queue.sync { + // DispatchSourceTimer will automatically cancel when released + temporaryTimer = nil + temporaryEventHandler = nil } - state = .suspended - timer.suspend() } - func fire() { - eventHandler?() + /** + True if the Timer exists and has not been cancelled. False otherwise + */ + var isValid: Bool { + guard let timer = temporaryTimer else { return false } + return !timer.isCancelled } - deinit { - timer.setEventHandler {} - timer.cancel() - /* - If the timer is suspended, calling cancel without resuming - triggers a crash. This is documented here https://forums.developer.apple.com/thread/15902 - */ - resume() - eventHandler = nil + /** + Calls the Timer's event handler immediately. This method + is primarily used in tests (not ideal) + */ + func fire() { + guard isValid else { return } + temporaryEventHandler?() } +} +extension HeartbeatTimer: Equatable { static func == (lhs: HeartbeatTimer, rhs: HeartbeatTimer) -> Bool { - return lhs.id == rhs.id + return lhs.uuid == rhs.uuid } } diff --git a/Sources/Realtime/Message.swift b/Sources/Realtime/Message.swift index a0dacd8..5047203 100644 --- a/Sources/Realtime/Message.swift +++ b/Sources/Realtime/Message.swift @@ -34,42 +34,54 @@ public class Message { /// Message event public let event: ChannelEvent + /// The raw payload from the Message, including a nested response from + /// phx_reply events. It is recommended to use `payload` instead. + internal let rawPayload: Payload + /// Message payload - public var payload: [String: Any] + public var payload: Payload { + guard let response = rawPayload["response"] as? Payload + else { return rawPayload } + return response + } /// Convenience accessor. Equivalent to getting the status as such: /// ```swift /// message.payload["status"] /// ``` - public var status: String? { - return payload["status"] as? String + public var status: PushStatus? { + guard let status = rawPayload["status"] as? String else { + return nil + } + return PushStatus(rawValue: status) } init( ref: String = "", topic: ChannelTopic = .all, event: ChannelEvent = .all, - payload: [String: Any] = [:], + payload: Payload = [:], joinRef: String? = nil ) { self.ref = ref self.topic = topic self.event = event - self.payload = payload + rawPayload = payload self.joinRef = joinRef } - init?(json: [String: Any]) { - ref = json["ref"] as? String ?? "" - joinRef = json["join_ref"] as? String + init?(json: [Any?]) { + guard json.count > 4 else { return nil } + joinRef = json[0] as? String + ref = json[1] as? String ?? "" - if let topic = json["topic"] as? String, - let event = json["event"] as? String, - let payload = json["payload"] as? [String: Any] + if let topic = (json[2] as? String).flatMap(ChannelTopic.init(rawValue:)), + let event = (json[3] as? String).flatMap(ChannelEvent.init(rawValue:)), + let payload = json[4] as? Payload { - self.topic = ChannelTopic(rawValue: topic) ?? .all - self.event = ChannelEvent(rawValue: event) ?? .all - self.payload = payload + self.topic = topic + self.event = event + rawPayload = payload } else { return nil } diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift new file mode 100644 index 0000000..67bf4c8 --- /dev/null +++ b/Sources/Realtime/Presence.swift @@ -0,0 +1,443 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +/// The Presence object provides features for syncing presence information from +/// the server with the client and handling presences joining and leaving. +/// +/// ## Syncing state from the server +/// +/// To sync presence state from the server, first instantiate an object and pass +/// your channel in to track lifecycle events: +/// +/// let channel = socket.channel("some:topic") +/// let presence = Presence(channel) +/// +/// If you have custom syncing state events, you can configure the `Presence` +/// object to use those instead. +/// +/// let options = Options(events: [.state: "my_state", .diff: "my_diff"]) +/// let presence = Presence(channel, opts: options) +/// +/// Next, use the presence.onSync callback to react to state changes from the +/// server. For example, to render the list of users every time the list +/// changes, you could write: +/// +/// presence.onSync { renderUsers(presence.list()) } +/// +/// ## Listing Presences +/// +/// presence.list is used to return a list of presence information based on the +/// local state of metadata. By default, all presence metadata is returned, but +/// a listBy function can be supplied to allow the client to select which +/// metadata to use for a given presence. For example, you may have a user +/// online from different devices with a metadata status of "online", but they +/// have set themselves to "away" on another device. In this case, the app may +/// choose to use the "away" status for what appears on the UI. The example +/// below defines a listBy function which prioritizes the first metadata which +/// was registered for each user. This could be the first tab they opened, or +/// the first device they came online from: +/// +/// let listBy: (String, Presence.Map) -> Presence.Meta = { id, pres in +/// let first = pres["metas"]!.first! +/// first["count"] = pres["metas"]!.count +/// first["id"] = id +/// return first +/// } +/// let onlineUsers = presence.list(by: listBy) +/// +/// (NOTE: The underlying behavior is a `map` on the `presence.state`. You are +/// mapping the `state` dictionary into whatever datastructure suites your needs) +/// +/// ## Handling individual presence join and leave events +/// +/// The presence.onJoin and presence.onLeave callbacks can be used to react to +/// individual presences joining and leaving the app. For example: +/// +/// let presence = Presence(channel) +/// presence.onJoin { [weak self] (key, current, newPres) in +/// if let cur = current { +/// print("user additional presence", cur) +/// } else { +/// print("user entered for the first time", newPres) +/// } +/// } +/// +/// presence.onLeave { [weak self] (key, current, leftPres) in +/// if current["metas"]?.isEmpty == true { +/// print("user has left from all devices", leftPres) +/// } else { +/// print("user left from a device", current) +/// } +/// } +/// +/// presence.onSync { renderUsers(presence.list()) } +public final class Presence { + // ---------------------------------------------------------------------- + + // MARK: - Enums and Structs + + // ---------------------------------------------------------------------- + /// Custom options that can be provided when creating Presence + /// + /// ### Example: + /// + /// let options = Options(events: [.state: "my_state", .diff: "my_diff"]) + /// let presence = Presence(channel, opts: options) + public struct Options { + let events: [Events: ChannelEvent] + + /// Default set of Options used when creating Presence. Uses the + /// phoenix events "presence_state" and "presence_diff" + public static let defaults = Options(events: [ + .state: .presenceState, + .diff: .presenceDiff, + ]) + + public init(events: [Events: ChannelEvent]) { + self.events = events + } + } + + /// Presense Events + public enum Events: String { + case state + case diff + } + + // ---------------------------------------------------------------------- + + // MARK: - Typaliases + + // ---------------------------------------------------------------------- + /// Meta details of a Presence. Just a dictionary of properties + public typealias Meta = [String: Any] + + /// A mapping of a String to an array of Metas. e.g. {"metas": [{id: 1}]} + public typealias Map = [String: [Meta]] + + /// A mapping of a Presence state to a mapping of Metas + public typealias State = [String: Map] + + // Diff has keys "joins" and "leaves", pointing to a Presence.State each + // containing the users that joined and left. + public typealias Diff = [String: State] + + /// Closure signature of OnJoin callbacks + public typealias OnJoin = (_ key: String, _ current: Map?, _ new: Map) -> Void + + /// Closure signature for OnLeave callbacks + public typealias OnLeave = (_ key: String, _ current: Map, _ left: Map) -> Void + + //// Closure signature for OnSync callbacks + public typealias OnSync = () -> Void + + /// Collection of callbacks with default values + struct Caller { + var onJoin: OnJoin = { _, _, _ in } + var onLeave: OnLeave = { _, _, _ in } + var onSync: OnSync = {} + } + + // ---------------------------------------------------------------------- + + // MARK: - Properties + + // ---------------------------------------------------------------------- + /// The channel the Presence belongs to + weak var channel: Channel? + + /// Caller to callback hooks + var caller: Caller + + /// The state of the Presence + public private(set) var state: State + + /// Pending `join` and `leave` diffs that need to be synced + public private(set) var pendingDiffs: [Diff] + + /// The channel's joinRef, set when state events occur + public private(set) var joinRef: String? + + public var isPendingSyncState: Bool { + guard let safeJoinRef = joinRef else { return true } + return safeJoinRef != channel?.joinRef + } + + /// Callback to be informed of joins + public var onJoin: OnJoin { + get { return caller.onJoin } + set { caller.onJoin = newValue } + } + + /// Set the OnJoin callback + public func onJoin(_ callback: @escaping OnJoin) { + onJoin = callback + } + + /// Callback to be informed of leaves + public var onLeave: OnLeave { + get { return caller.onLeave } + set { caller.onLeave = newValue } + } + + /// Set the OnLeave callback + public func onLeave(_ callback: @escaping OnLeave) { + onLeave = callback + } + + /// Callback to be informed of synces + public var onSync: OnSync { + get { return caller.onSync } + set { caller.onSync = newValue } + } + + /// Set the OnSync callback + public func onSync(_ callback: @escaping OnSync) { + onSync = callback + } + + public init(channel: Channel, opts: Options = Options.defaults) { + state = [:] + pendingDiffs = [] + self.channel = channel + joinRef = nil + caller = Caller() + + guard // Do not subscribe to events if they were not provided + let stateEvent = opts.events[.state], + let diffEvent = opts.events[.diff] + else { return } + + self.channel?.delegateOn(stateEvent, to: self) { (self, message) in + guard let newState = message.rawPayload as? State else { return } + + self.joinRef = self.channel?.joinRef + self.state = Presence.syncState( + self.state, + newState: newState, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + + self.pendingDiffs.forEach { diff in + self.state = Presence.syncDiff( + self.state, + diff: diff, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + } + + self.pendingDiffs = [] + self.caller.onSync() + } + + self.channel?.delegateOn(diffEvent, to: self) { (self, message) in + guard let diff = message.rawPayload as? Diff else { return } + if self.isPendingSyncState { + self.pendingDiffs.append(diff) + } else { + self.state = Presence.syncDiff( + self.state, + diff: diff, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + self.caller.onSync() + } + } + } + + /// Returns the array of presences, with deault selected metadata. + public func list() -> [Map] { + return list(by: { _, pres in pres }) + } + + /// Returns the array of presences, with selected metadata + public func list(by transformer: (String, Map) -> T) -> [T] { + return Presence.listBy(state, transformer: transformer) + } + + /// Filter the Presence state with a given function + public func filter(by filter: ((String, Map) -> Bool)?) -> State { + return Presence.filter(state, by: filter) + } + + // ---------------------------------------------------------------------- + + // MARK: - Static + + // ---------------------------------------------------------------------- + + // Used to sync the list of presences on the server + // with the client's state. An optional `onJoin` and `onLeave` callback can + // be provided to react to changes in the client's local presences across + // disconnects and reconnects with the server. + // + // - returns: Presence.State + @discardableResult + public static func syncState( + _ currentState: State, + newState: State, + onJoin: OnJoin = { _, _, _ in }, + onLeave: OnLeave = { _, _, _ in } + ) -> State { + let state = currentState + var leaves: Presence.State = [:] + var joins: Presence.State = [:] + + state.forEach { key, presence in + if newState[key] == nil { + leaves[key] = presence + } + } + + newState.forEach { key, newPresence in + if let currentPresence = state[key] { + let newRefs = newPresence["metas"]!.map { $0["phx_ref"] as! String } + let curRefs = currentPresence["metas"]!.map { $0["phx_ref"] as! String } + + let joinedMetas = newPresence["metas"]!.filter { (meta: Meta) -> Bool in + !curRefs.contains { $0 == meta["phx_ref"] as! String } + } + let leftMetas = currentPresence["metas"]!.filter { (meta: Meta) -> Bool in + !newRefs.contains { $0 == meta["phx_ref"] as! String } + } + + if joinedMetas.count > 0 { + joins[key] = newPresence + joins[key]!["metas"] = joinedMetas + } + + if leftMetas.count > 0 { + leaves[key] = currentPresence + leaves[key]!["metas"] = leftMetas + } + } else { + joins[key] = newPresence + } + } + + return Presence.syncDiff( + state, + diff: ["joins": joins, "leaves": leaves], + onJoin: onJoin, + onLeave: onLeave + ) + } + + // Used to sync a diff of presence join and leave + // events from the server, as they happen. Like `syncState`, `syncDiff` + // accepts optional `onJoin` and `onLeave` callbacks to react to a user + // joining or leaving from a device. + // + // - returns: Presence.State + @discardableResult + public static func syncDiff( + _ currentState: State, + diff: Diff, + onJoin: OnJoin = { _, _, _ in }, + onLeave: OnLeave = { _, _, _ in } + ) -> State { + var state = currentState + diff["joins"]?.forEach { key, newPresence in + let currentPresence = state[key] + state[key] = newPresence + + if let curPresence = currentPresence { + let joinedRefs = state[key]!["metas"]!.map { $0["phx_ref"] as! String } + let curMetas = curPresence["metas"]!.filter { (meta: Meta) -> Bool in + !joinedRefs.contains { $0 == meta["phx_ref"] as! String } + } + state[key]!["metas"]!.insert(contentsOf: curMetas, at: 0) + } + + onJoin(key, currentPresence, newPresence) + } + + diff["leaves"]?.forEach { key, leftPresence in + guard var curPresence = state[key] else { return } + let refsToRemove = leftPresence["metas"]!.map { $0["phx_ref"] as! String } + let keepMetas = curPresence["metas"]!.filter { (meta: Meta) -> Bool in + !refsToRemove.contains { $0 == meta["phx_ref"] as! String } + } + + curPresence["metas"] = keepMetas + onLeave(key, curPresence, leftPresence) + + if keepMetas.count > 0 { + state[key]!["metas"] = keepMetas + } else { + state.removeValue(forKey: key) + } + } + + return state + } + + public static func filter( + _ presences: State, + by filter: ((String, Map) -> Bool)? + ) -> State { + let safeFilter = filter ?? { _, _ in true } + return presences.filter(safeFilter) + } + + public static func listBy( + _ presences: State, + transformer: (String, Map) -> T + ) -> [T] { + return presences.map(transformer) + } +} + +extension Presence.Map { + + /// Decodes the presence metadata to an array of the specified type. + /// - parameter type: The type to decode to. + /// - parameter decoder: The decoder to use. + /// - returns: The decoded values. + /// - throws: Any error that occurs during decoding. + public func decode( + to type: T.Type = T.self, decoder: JSONDecoder = Defaults.decoder + ) throws -> [T] { + let metas: [Presence.Meta] = self["metas"]! + let data = try JSONSerialization.data(withJSONObject: metas) + return try decoder.decode([T].self, from: data) + } + +} + +extension Presence.State { + + /// Decodes the presence metadata to a dictionary of arrays of the specified type. + /// - parameter type: The type to decode to. + /// - parameter decoder: The decoder to use. + /// - returns: The dictionary of decoded values. + /// - throws: Any error that occurs during decoding. + public func decode( + to type: T.Type = T.self, decoder: JSONDecoder = Defaults.decoder + ) throws -> [String: [T]] { + return try mapValues { try $0.decode(decoder: decoder) } + } + +} diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift index b5ca2dc..0eb5e8b 100644 --- a/Sources/Realtime/Push.swift +++ b/Sources/Realtime/Push.swift @@ -29,7 +29,7 @@ public class Push { public let event: ChannelEvent /// The payload, for example ["user_id": "abc123"] - public var payload: [String: Any] + public var payload: Payload /// The push timeout. Default is 10.0 seconds public var timeout: TimeInterval @@ -44,7 +44,7 @@ public class Push { var timeoutWorkItem: DispatchWorkItem? /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored - var receiveHooks: [String: [Delegated]] + var receiveHooks: [PushStatus: [Delegated]] /// True if the Push has been sent var sent: Bool @@ -64,7 +64,7 @@ public class Push { init( channel: Channel, event: ChannelEvent, - payload: [String: Any] = [:], + payload: Payload = [:], timeout: TimeInterval = Defaults.timeoutInterval ) { self.channel = channel @@ -89,7 +89,7 @@ public class Push { /// Sends the Push. If it has already timed out, then the call will /// be ignored and return early. Use `resend` in this case. public func send() { - guard !hasReceived(status: "timeout") else { return } + guard !hasReceived(status: .timeout) else { return } startTimeout() sent = true @@ -120,7 +120,7 @@ public class Push { /// - parameter callback: Callback to fire when the status is recevied @discardableResult public func receive( - _ status: String, + _ status: PushStatus, callback: @escaping ((Message) -> Void) ) -> Push { var delegated = Delegated() @@ -146,7 +146,7 @@ public class Push { /// - parameter callback: Callback to fire when the status is recevied @discardableResult public func delegateReceive( - _ status: String, + _ status: PushStatus, to owner: Target, callback: @escaping ((Target, Message) -> Void) ) -> Push { @@ -158,9 +158,9 @@ public class Push { /// Shared behavior between `receive` calls @discardableResult - internal func receive(_ status: String, delegated: Delegated) -> Push { + internal func receive(_ status: PushStatus, delegated: Delegated) -> Push { // If the message has already been received, pass it to the callback immediately - if hasReceived(status: status), let receivedMessage = self.receivedMessage { + if hasReceived(status: status), let receivedMessage = receivedMessage { delegated.call(receivedMessage) } @@ -188,13 +188,13 @@ public class Push { /// /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" /// - parameter response: Response that was received - private func matchReceive(_ status: String, message: Message) { + private func matchReceive(_ status: PushStatus, message: Message) { receiveHooks[status]?.forEach { $0.call(message) } } /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push private func cancelRefEvent() { - guard let refEvent = self.refEvent else { return } + guard let refEvent = refEvent else { return } channel?.off(refEvent) } @@ -237,7 +237,7 @@ public class Push { /// Setup and start the Timeout timer. let workItem = DispatchWorkItem { - self.trigger("timeout", payload: [:]) + self.trigger(.timeout, payload: [:]) } timeoutWorkItem = workItem @@ -248,14 +248,14 @@ public class Push { /// /// - parameter status: Status to check /// - return: True if given status has been received by the Push. - internal func hasReceived(status: String) -> Bool { + internal func hasReceived(status: PushStatus) -> Bool { return receivedMessage?.status == status } /// Triggers an event to be sent though the Channel - internal func trigger(_ status: String, payload: [String: Any]) { + internal func trigger(_ status: PushStatus, payload: Payload) { /// If there is no ref event, then there is nothing to trigger on the channel - guard let refEvent = self.refEvent else { return } + guard let refEvent = refEvent else { return } var mutPayload = payload mutPayload["status"] = status diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift index 40f7ee8..3133a2d 100644 --- a/Sources/Realtime/RealtimeClient.swift +++ b/Sources/Realtime/RealtimeClient.swift @@ -1,23 +1,3 @@ -// Copyright (c) 2021 Supabase -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - // Copyright (c) 2021 David Stump // // Permission is hereby granted, free of charge, to any person obtaining a copy @@ -40,25 +20,45 @@ import Foundation -#if canImport(FoundationNetworking) - import FoundationNetworking -#endif +public enum SocketError: Error { + case abnormalClosureError +} + +/// Alias for a JSON dictionary [String: Any] +public typealias Payload = [String: Any] -/// ## RealtimeClient +/// Alias for a function returning an optional JSON dictionary (`Payload?`) +public typealias PayloadClosure = () -> Payload? + +/// Struct that gathers callbacks assigned to the Socket +struct StateChangeCallbacks { + var open: [(ref: String, callback: Delegated)] = [] + var close: [(ref: String, callback: Delegated<(Int, String?), Void>)] = [] + var error: [(ref: String, callback: Delegated<(Error, URLResponse?), Void>)] = [] + var message: [(ref: String, callback: Delegated)] = [] +} + +/// ## Socket Connection +/// A single connection is established to the server and +/// channels are multiplexed over the connection. +/// Connect to the server using the `Socket` class: /// /// ```swift -/// let socket = new RealtimeClient("/socket", params: { ["apikey": "123" ] }) +/// let socket = new Socket("/socket", paramsClosure: { ["userToken": "123" ] }) /// socket.connect() /// ``` /// +/// The `Socket` constructor takes the mount point of the socket, +/// the authentication params, as well as options that can be found in +/// the Socket docs, such as configuring the heartbeat. public class RealtimeClient: TransportDelegate { // ---------------------------------------------------------------------- // MARK: - Public Attributes // ---------------------------------------------------------------------- - /// The string WebSocket endpoint (ie `"ws://supabase.io/realtime/v1"`, - /// `"wss://supabase.io/realtime/v1"`, etc.) That was passed to the Socket during + /// The string WebSocket endpoint (ie `"ws://example.com/socket"`, + /// `"wss://example.com"`, etc.) That was passed to the Socket during /// initialization. The URL endpoint will be modified by the Socket to /// include `"/websocket"` if missing. public let endPoint: String @@ -66,20 +66,29 @@ public class RealtimeClient: TransportDelegate { /// The fully qualified socket URL public private(set) var endPointUrl: URL - /// Resolves to return the `params` result at the time of calling. + /// Resolves to return the `paramsClosure` result at the time of calling. /// If the `Socket` was created with static params, then those will be /// returned every time. - public var params: [String: Any]? + public var params: Payload? { + return paramsClosure?() + } + + /// The optional params closure used to get params when connecting. Must + /// be set when initializing the Socket. + public let paramsClosure: PayloadClosure? + + /// The WebSocket transport. Default behavior is to provide a + /// URLSessionWebsocketTask. See README for alternatives. + private let transport: (URL) -> Transport - /// The WebSocket transport. Default behavior is to provide a Starscream - /// WebSocket instance. Potentially allows changing WebSockets in future - private let transport: Transport + /// Phoenix serializer version, defaults to "2.0.0" + public let vsn: String /// Override to provide custom encoding of data before writing to the socket - public var encode: ([String: Any]) -> Data = Defaults.encode + public var encode: (Any) -> Data = Defaults.encode - /// Override to provide customd decoding of data read from the socket - public var decode: (Data) -> [String: Any]? = Defaults.decode + /// Override to provide custom decoding of data read from the socket + public var decode: (Data) -> Any? = Defaults.decode /// Timeout to use when opening connections public var timeout: TimeInterval = Defaults.timeoutInterval @@ -87,6 +96,9 @@ public class RealtimeClient: TransportDelegate { /// Interval between sending a heartbeat public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval + /// The maximum amount of time which the system may delay heartbeats in order to optimize power usage + public var heartbeatLeeway: DispatchTimeInterval = Defaults.heartbeatLeeway + /// Interval between socket reconnect attempts, in seconds public var reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff @@ -121,20 +133,17 @@ public class RealtimeClient: TransportDelegate { // ---------------------------------------------------------------------- /// Callbacks for socket state changes - var stateChangeCallbacks = StateChangeCallbacks() + var stateChangeCallbacks: StateChangeCallbacks = .init() /// Collection on channels created for the Socket - var channels: [Channel] = [] + public internal(set) var channels: [Channel] = [] /// Buffers messages that need to be sent once the socket has connected. It is an array /// of tuples, with the ref of the message to send and the callback that will send the message. var sendBuffer: [(ref: String?, callback: () throws -> Void)] = [] /// Ref counter for messages - var ref = UInt64.min // 0 (max: 18,446,744,073,709,551,615) - - /// Queue to run heartbeat timer on - var heartbeatQueue = DispatchQueue(label: "com.supabase.realtime.socket.heartbeat") + var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) /// Timer that triggers sending new Heartbeat messages var heartbeatTimer: HeartbeatTimer? @@ -145,32 +154,79 @@ public class RealtimeClient: TransportDelegate { /// Timer to use when attempting to reconnect var reconnectTimer: TimeoutTimer - /// True if the Socket closed cleaned. False if not (connection timeout, heartbeat, etc) - var closeWasClean: Bool = false + /// Close status + var closeStatus: CloseStatus = .unknown /// The connection to the server - var connection: Transport? + var connection: Transport? = nil // ---------------------------------------------------------------------- // MARK: - Initialization // ---------------------------------------------------------------------- + @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) + public convenience init( + _ endPoint: String, + params: Payload? = nil, + vsn: String = Defaults.vsn + ) { + self.init( + endPoint: endPoint, + transport: { url in URLSessionTransport(url: url) }, + paramsClosure: { params }, + vsn: vsn + ) + } + + @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) + public convenience init( + _ endPoint: String, + paramsClosure: PayloadClosure?, + vsn: String = Defaults.vsn + ) { + self.init( + endPoint: endPoint, + transport: { url in URLSessionTransport(url: url) }, + paramsClosure: paramsClosure, + vsn: vsn + ) + } + + @available(*, deprecated, renamed: "init(_:params:vsn:)") + public convenience init( + endPoint: String, + params: Payload? = nil, + vsn: String = Defaults.vsn + ) { + self.init( + endPoint: endPoint, + transport: { url in URLSessionTransport(url: url) }, + paramsClosure: { params }, + vsn: vsn + ) + } + public init( endPoint: String, - params: [String: Any]? = nil + transport: @escaping ((URL) -> Transport), + paramsClosure: PayloadClosure? = nil, + vsn: String = Defaults.vsn ) { + self.transport = transport + self.paramsClosure = paramsClosure + self.endPoint = endPoint + self.vsn = vsn endPointUrl = RealtimeClient.buildEndpointUrl( endpoint: endPoint, - params: params) - transport = URLSessionTransport(url: endPointUrl) - self.params = params - self.endPoint = endPoint + paramsClosure: paramsClosure, + vsn: vsn + ) reconnectTimer = TimeoutTimer() reconnectTimer.callback.delegate(to: self) { (self) in self.logItems("Socket attempting to reconnect") - self.teardown { self.connect() } + self.teardown(reason: "reconnection") { self.connect() } } reconnectTimer.timerCalculation .delegate(to: self) { (self, tries) -> TimeInterval in @@ -200,7 +256,12 @@ public class RealtimeClient: TransportDelegate { /// - return: True if the socket is connected public var isConnected: Bool { - return connection?.readyState == .open + return connectionState == .open + } + + /// - return: The state of the connect. [.connecting, .open, .closing, .closed] + public var connectionState: TransportReadyState { + return connection?.readyState ?? .closed } /// Connects the Socket. The params passed to the Socket on initialization @@ -210,17 +271,26 @@ public class RealtimeClient: TransportDelegate { // Do not attempt to reconnect if the socket is currently connected guard !isConnected else { return } - // Reset the clean close flag when attempting to connect - closeWasClean = false + // Reset the close status when attempting to connect + closeStatus = .unknown // We need to build this right before attempting to connect as the // parameters could be built upon demand and change over time endPointUrl = RealtimeClient.buildEndpointUrl( endpoint: endPoint, - params: params) + paramsClosure: paramsClosure, + vsn: vsn + ) - connection = transport + connection = transport(endPointUrl) connection?.delegate = self + // self.connection?.disableSSLCertValidation = disableSSLCertValidation + // + // #if os(Linux) + // #else + // self.connection?.security = security + // self.connection?.enabledSSLCipherSuites = enabledSSLCipherSuites + // #endif connection?.connect() } @@ -228,31 +298,33 @@ public class RealtimeClient: TransportDelegate { /// Disconnects the socket /// /// - parameter code: Optional. Closing status code - /// - paramter callback: Optional. Called when disconnected + /// - parameter callback: Optional. Called when disconnected public func disconnect( code: CloseCode = CloseCode.normal, + reason: String? = nil, callback: (() -> Void)? = nil ) { // The socket was closed cleanly by the User - closeWasClean = true + closeStatus = CloseStatus(closeCode: code.rawValue) // Reset any reconnects and teardown the socket connection reconnectTimer.reset() - teardown(code: code, callback: callback) + teardown(code: code, reason: reason, callback: callback) } - internal func teardown(code: CloseCode = CloseCode.normal, callback: (() -> Void)? = nil) { + internal func teardown( + code: CloseCode = CloseCode.normal, reason: String? = nil, callback: (() -> Void)? = nil + ) { connection?.delegate = nil - connection?.disconnect(code: code.rawValue, reason: nil) + connection?.disconnect(code: code.rawValue, reason: reason) connection = nil // The socket connection has been torndown, heartbeats are not needed - heartbeatTimer?.stopTimer() - heartbeatTimer = nil + heartbeatTimer?.stop() // Since the connection's delegate was nil'd out, inform all state // callbacks that the connection has closed - stateChangeCallbacks.close.forEach { $0.callback.call() } + stateChangeCallbacks.close.forEach { $0.callback.call((code.rawValue, reason)) } callback?() } @@ -274,7 +346,22 @@ public class RealtimeClient: TransportDelegate { /// - parameter callback: Called when the Socket is opened @discardableResult public func onOpen(callback: @escaping () -> Void) -> String { - var delegated = Delegated() + return onOpen { _ in callback() } + } + + /// Registers callbacks for connection open events. Does not handle retain + /// cycles. Use `delegateOnOpen(to:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onOpen() { [weak self] response in + /// self?.print("Socket Connection Open") + /// } + /// + /// - parameter callback: Called when the Socket is opened + @discardableResult + public func onOpen(callback: @escaping (URLResponse?) -> Void) -> String { + var delegated = Delegated() delegated.manuallyDelegate(with: callback) return append(callback: delegated, to: &stateChangeCallbacks.open) @@ -296,7 +383,26 @@ public class RealtimeClient: TransportDelegate { to owner: T, callback: @escaping ((T) -> Void) ) -> String { - var delegated = Delegated() + return delegateOnOpen(to: owner) { owner, _ in callback(owner) } + } + + /// Registers callbacks for connection open events. Automatically handles + /// retain cycles. Use `onOpen()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnOpen(to: self) { self, response in + /// self.print("Socket Connection Open") + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket is opened + @discardableResult + public func delegateOnOpen( + to owner: T, + callback: @escaping ((T, URLResponse?) -> Void) + ) -> String { + var delegated = Delegated() delegated.delegate(to: owner, with: callback) return append(callback: delegated, to: &stateChangeCallbacks.open) @@ -314,7 +420,22 @@ public class RealtimeClient: TransportDelegate { /// - parameter callback: Called when the Socket is closed @discardableResult public func onClose(callback: @escaping () -> Void) -> String { - var delegated = Delegated() + return onClose { _, _ in callback() } + } + + /// Registers callbacks for connection close events. Does not handle retain + /// cycles. Use `delegateOnClose(_:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onClose() { [weak self] code, reason in + /// self?.print("Socket Connection Close") + /// } + /// + /// - parameter callback: Called when the Socket is closed + @discardableResult + public func onClose(callback: @escaping (Int, String?) -> Void) -> String { + var delegated = Delegated<(Int, String?), Void>() delegated.manuallyDelegate(with: callback) return append(callback: delegated, to: &stateChangeCallbacks.close) @@ -336,7 +457,26 @@ public class RealtimeClient: TransportDelegate { to owner: T, callback: @escaping ((T) -> Void) ) -> String { - var delegated = Delegated() + return delegateOnClose(to: owner) { owner, _ in callback(owner) } + } + + /// Registers callbacks for connection close events. Automatically handles + /// retain cycles. Use `onClose()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnClose(self) { self, code, reason in + /// self.print("Socket Connection Close") + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket is closed + @discardableResult + public func delegateOnClose( + to owner: T, + callback: @escaping ((T, (Int, String?)) -> Void) + ) -> String { + var delegated = Delegated<(Int, String?), Void>() delegated.delegate(to: owner, with: callback) return append(callback: delegated, to: &stateChangeCallbacks.close) @@ -353,8 +493,8 @@ public class RealtimeClient: TransportDelegate { /// /// - parameter callback: Called when the Socket errors @discardableResult - public func onError(callback: @escaping (Error) -> Void) -> String { - var delegated = Delegated() + public func onError(callback: @escaping ((Error, URLResponse?)) -> Void) -> String { + var delegated = Delegated<(Error, URLResponse?), Void>() delegated.manuallyDelegate(with: callback) return append(callback: delegated, to: &stateChangeCallbacks.error) @@ -374,9 +514,9 @@ public class RealtimeClient: TransportDelegate { @discardableResult public func delegateOnError( to owner: T, - callback: @escaping ((T, Error) -> Void) + callback: @escaping ((T, (Error, URLResponse?)) -> Void) ) -> String { - var delegated = Delegated() + var delegated = Delegated<(Error, URLResponse?), Void>() delegated.delegate(to: owner, with: callback) return append(callback: delegated, to: &stateChangeCallbacks.error) @@ -442,7 +582,24 @@ public class RealtimeClient: TransportDelegate { // ---------------------------------------------------------------------- // MARK: - Channel Initialization - + // ---------------------------------------------------------------------- + /// Initialize a new Channel + /// + /// Example: + /// + /// let channel = socket.channel("rooms", options: ChannelOptions(presenceKey: "user123")) + /// + /// - parameter topic: Topic of the channel + /// - parameter options: Optional. Options to configure channel broadcast and presence. Leave nil for postgres channel. + /// - return: A new channel + public func channel( + _ topic: ChannelTopic, + options: ChannelOptions? = nil + ) -> Channel { + let channel = Channel(topic: topic, options: options, socket: self) + channels.append(channel) + return channel + } // ---------------------------------------------------------------------- /// Initialize a new Channel /// @@ -453,9 +610,10 @@ public class RealtimeClient: TransportDelegate { /// - parameter topic: Topic of the channel /// - parameter params: Optional. Parameters for the channel /// - return: A new channel + @available(*, deprecated, renamed: "channel(_:options:)") public func channel( _ topic: ChannelTopic, - params: [String: Any] = [:] + params: [String: Any] ) -> Channel { let channel = Channel(topic: topic, params: params, socket: self) channels.append(channel) @@ -483,10 +641,18 @@ public class RealtimeClient: TransportDelegate { /// /// - Parameter refs: List of refs returned by calls to `onOpen`, `onClose`, etc public func off(_ refs: [String]) { - stateChangeCallbacks.open = stateChangeCallbacks.open.filter { !refs.contains($0.ref) } - stateChangeCallbacks.close = stateChangeCallbacks.close.filter { !refs.contains($0.ref) } - stateChangeCallbacks.error = stateChangeCallbacks.error.filter { !refs.contains($0.ref) } - stateChangeCallbacks.message = stateChangeCallbacks.message.filter { !refs.contains($0.ref) } + stateChangeCallbacks.open = stateChangeCallbacks.open.filter { + !refs.contains($0.ref) + } + stateChangeCallbacks.close = stateChangeCallbacks.close.filter { + !refs.contains($0.ref) + } + stateChangeCallbacks.error = stateChangeCallbacks.error.filter { + !refs.contains($0.ref) + } + stateChangeCallbacks.message = stateChangeCallbacks.message.filter { + !refs.contains($0.ref) + } } // ---------------------------------------------------------------------- @@ -506,20 +672,12 @@ public class RealtimeClient: TransportDelegate { internal func push( topic: ChannelTopic, event: ChannelEvent, - payload: [String: Any], + payload: Payload, ref: String? = nil, joinRef: String? = nil ) { let callback: (() throws -> Void) = { - var body: [String: Any] = [ - "topic": topic.rawValue, - "event": event.rawValue, - "payload": payload, - ] - - if let safeRef = ref { body["ref"] = safeRef } - if let safeJoinRef = joinRef { body["join_ref"] = safeJoinRef } - + let body: [Any?] = [joinRef, ref, topic.rawValue, event.rawValue, payload] let data = self.encode(body) self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") @@ -544,7 +702,7 @@ public class RealtimeClient: TransportDelegate { /// Logs the message. Override Socket.logger for specialized logging. noops by default /// - /// - paramter items: List of items to be logged. Behaves just like debugPrint() + /// - parameter items: List of items to be logged. Behaves just like debugPrint() func logItems(_ items: Any...) { let msg = items.map { String(describing: $0) }.joined(separator: ", ") logger?("SwiftPhoenixClient: \(msg)") @@ -556,11 +714,11 @@ public class RealtimeClient: TransportDelegate { // ---------------------------------------------------------------------- /// Called when the underlying Websocket connects to it's host - internal func onConnectionOpen() { + internal func onConnectionOpen(response: URLResponse?) { logItems("transport", "Connected to \(endPoint)") - // Reset the closeWasClean flag now that the socket has been connected - closeWasClean = false + // Reset the close status now that the socket has been connected + closeStatus = .unknown // Send any messages that were waiting for a connection flushSendBuffer() @@ -572,33 +730,35 @@ public class RealtimeClient: TransportDelegate { resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - stateChangeCallbacks.open.forEach { $0.callback.call() } + stateChangeCallbacks.open.forEach { $0.callback.call(response) } } - internal func onConnectionClosed(code _: Int?) { + internal func onConnectionClosed(code: Int, reason: String?) { logItems("transport", "close") + + // Send an error to all channels triggerChannelError() // Prevent the heartbeat from triggering if the - heartbeatTimer?.stopTimer() - heartbeatTimer = nil + heartbeatTimer?.stop() - // Only attempt to reconnect if the socket did not close normally - if !closeWasClean { + // Only attempt to reconnect if the socket did not close normally, + // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) + if closeStatus.shouldReconnect { reconnectTimer.scheduleTimeout() } - stateChangeCallbacks.close.forEach { $0.callback.call() } + stateChangeCallbacks.close.forEach { $0.callback.call((code, reason)) } } - internal func onConnectionError(_ error: Error) { - logItems("transport", error) + internal func onConnectionError(_ error: Error, response: URLResponse?) { + logItems("transport", error, response ?? "") // Send an error to all channels triggerChannelError() - // Inform any state callabcks of the error - stateChangeCallbacks.error.forEach { $0.callback.call(error) } + // Inform any state callbacks of the error + stateChangeCallbacks.error.forEach { $0.callback.call((error, response)) } } internal func onConnectionMessage(_ rawMessage: String) { @@ -606,7 +766,7 @@ public class RealtimeClient: TransportDelegate { guard let data = rawMessage.data(using: String.Encoding.utf8), - let json = decode(data), + let json = decode(data) as? [Any?], let message = Message(json: json) else { logItems("receive: Unable to parse JSON: \(rawMessage)") @@ -641,7 +801,7 @@ public class RealtimeClient: TransportDelegate { /// Send all messages that were buffered before the socket opened internal func flushSendBuffer() { - guard isConnected, sendBuffer.count > 0 else { return } + guard isConnected && sendBuffer.count > 0 else { return } sendBuffer.forEach { try? $0.callback() } sendBuffer = [] } @@ -652,7 +812,9 @@ public class RealtimeClient: TransportDelegate { } /// Builds a fully qualified socket `URL` from `endPoint` and `params`. - internal static func buildEndpointUrl(endpoint: String, params: [String: Any]?) -> URL { + internal static func buildEndpointUrl( + endpoint: String, paramsClosure params: PayloadClosure?, vsn: String + ) -> URL { guard let url = URL(string: endpoint), var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) @@ -669,11 +831,14 @@ public class RealtimeClient: TransportDelegate { urlComponents.path.append("websocket") } + urlComponents.queryItems = [URLQueryItem(name: "vsn", value: vsn)] + // If there are parameters, append them to the URL - if let params = params { - urlComponents.queryItems = params.map { - URLQueryItem(name: $0.key, value: String(describing: $0.value)) - } + if let params = params?() { + urlComponents.queryItems?.append( + contentsOf: params.map { + URLQueryItem(name: $0.key, value: String(describing: $0.value)) + }) } guard let qualifiedUrl = urlComponents.url @@ -688,7 +853,7 @@ public class RealtimeClient: TransportDelegate { else { return } logItems("transport", "leaving duplicate topic: [\(topic)]") - dupe.unsubscribe() + dupe.leave() } // ---------------------------------------------------------------------- @@ -699,19 +864,18 @@ public class RealtimeClient: TransportDelegate { internal func resetHeartbeat() { // Clear anything related to the heartbeat pendingHeartbeatRef = nil - heartbeatTimer?.stopTimer() - heartbeatTimer = nil + heartbeatTimer?.stop() // Do not start up the heartbeat timer if skipHeartbeat is true guard !skipHeartbeat else { return } - heartbeatTimer = HeartbeatTimer(timeInterval: heartbeatInterval, dispatchQueue: heartbeatQueue) - heartbeatTimer?.startTimerWithEvent(eventHandler: { [weak self] in + heartbeatTimer = HeartbeatTimer(timeInterval: heartbeatInterval, leeway: heartbeatLeeway) + heartbeatTimer?.start(eventHandler: { [weak self] in self?.sendHeartbeat() }) } - /// Sends a hearbeat payload to the phoenix serverss + /// Sends a heartbeat payload to the phoenix servers @objc func sendHeartbeat() { // Do not send if the connection is closed guard isConnected else { return } @@ -723,7 +887,8 @@ public class RealtimeClient: TransportDelegate { pendingHeartbeatRef = nil logItems( "transport", - "heartbeat timeout. Attempting to re-establish connection") + "heartbeat timeout. Attempting to re-establish connection" + ) // Close the socket manually, flagging the closure as abnormal. Do not use // `teardown` or `disconnect` as they will nil out the websocket delegate. @@ -738,16 +903,20 @@ public class RealtimeClient: TransportDelegate { topic: .heartbeat, event: ChannelEvent.heartbeat, payload: [:], - ref: pendingHeartbeatRef) + ref: pendingHeartbeatRef + ) } internal func abnormalClose(_ reason: String) { - closeWasClean = false + closeStatus = .abnormal /* We use NORMAL here since the client is the one determining to close the - connection. However, we keep a flag `closeWasClean` set to false so that + connection. However, we set to close status to abnormal so that the client knows that it should attempt to reconnect. + + If the server subsequently acknowledges with code 1000 (normal close), + the socket will keep the `.abnormal` close status and trigger a reconnection. */ connection?.disconnect(code: CloseCode.normal.rawValue, reason: reason) } @@ -757,21 +926,21 @@ public class RealtimeClient: TransportDelegate { // MARK: - TransportDelegate // ---------------------------------------------------------------------- - public func onOpen() { - onConnectionOpen() + public func onOpen(response: URLResponse?) { + onConnectionOpen(response: response) } - public func onError(error: Error) { - onConnectionError(error) + public func onError(error: Error, response: URLResponse?) { + onConnectionError(error, response: response) } public func onMessage(message: String) { onConnectionMessage(message) } - public func onClose(code: Int) { - closeWasClean = code != CloseCode.abnormal.rawValue - onConnectionClosed(code: code) + public func onClose(code: Int, reason: String? = nil) { + closeStatus.update(transportCloseCode: code) + onConnectionClosed(code: code, reason: reason) } } @@ -789,3 +958,58 @@ extension RealtimeClient { case goingAway = 1001 } } + +// ---------------------------------------------------------------------- + +// MARK: - Close Status + +// ---------------------------------------------------------------------- +extension RealtimeClient { + /// Indicates the different closure states a socket can be in. + enum CloseStatus { + /// Undetermined closure state + case unknown + /// A clean closure requested either by the client or the server + case clean + /// An abnormal closure requested by the client + case abnormal + + /// Temporarily close the socket, pausing reconnect attempts. Useful on mobile + /// clients when disconnecting a because the app resigned active but should + /// reconnect when app enters active state. + case temporary + + init(closeCode: Int) { + switch closeCode { + case CloseCode.abnormal.rawValue: + self = .abnormal + case CloseCode.goingAway.rawValue: + self = .temporary + default: + self = .clean + } + } + + mutating func update(transportCloseCode: Int) { + switch self { + case .unknown, .clean, .temporary: + // Allow transport layer to override these statuses. + self = .init(closeCode: transportCloseCode) + case .abnormal: + // Do not allow transport layer to override the abnormal close status. + // The socket itself should reset it on the next connection attempt. + // See `Socket.abnormalClose(_:)` for more information. + break + } + } + + var shouldReconnect: Bool { + switch self { + case .unknown, .abnormal: + return true + case .clean, .temporary: + return false + } + } + } +} diff --git a/Sources/Realtime/SocketError.swift b/Sources/Realtime/SocketError.swift deleted file mode 100644 index e23911b..0000000 --- a/Sources/Realtime/SocketError.swift +++ /dev/null @@ -1,3 +0,0 @@ -public enum SocketError: Error { - case abnormalClosureError -} diff --git a/Sources/Realtime/StateChangeCallbacks.swift b/Sources/Realtime/StateChangeCallbacks.swift deleted file mode 100644 index 20e99ce..0000000 --- a/Sources/Realtime/StateChangeCallbacks.swift +++ /dev/null @@ -1,7 +0,0 @@ -/// Struct that gathers callbacks assigned to the Socket -struct StateChangeCallbacks { - var open: [(ref: String, callback: Delegated)] = [] - var close: [(ref: String, callback: Delegated)] = [] - var error: [(ref: String, callback: Delegated)] = [] - var message: [(ref: String, callback: Delegated)] = [] -} diff --git a/Sources/Realtime/SynchronizedArray.swift b/Sources/Realtime/SynchronizedArray.swift new file mode 100644 index 0000000..e7345ce --- /dev/null +++ b/Sources/Realtime/SynchronizedArray.swift @@ -0,0 +1,33 @@ +// +// SynchronizedArray.swift +// SwiftPhoenixClient +// +// Created by Daniel Rees on 4/12/23. +// Copyright © 2023 SwiftPhoenixClient. All rights reserved. +// + +import Foundation + +/// A thread-safe array. +public class SynchronizedArray { + fileprivate let queue = DispatchQueue(label: "spc_sync_array", attributes: .concurrent) + fileprivate var array = [Element]() + + func append(_ newElement: Element) { + queue.async(flags: .barrier) { + self.array.append(newElement) + } + } + + func removeAll(where shouldBeRemoved: @escaping (Element) -> Bool) { + queue.async(flags: .barrier) { + self.array.removeAll(where: shouldBeRemoved) + } + } + + func filter(_ isIncluded: (Element) -> Bool) -> [Element] { + var result = [Element]() + queue.sync { result = self.array.filter(isIncluded) } + return result + } +} diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift index 9882d61..b6b37c4 100644 --- a/Sources/Realtime/TimeoutTimer.swift +++ b/Sources/Realtime/TimeoutTimer.swift @@ -57,7 +57,7 @@ class TimeoutTimer { var tries: Int = 0 /// The Queue to execute on. In testing, this is overridden - var queue = TimerQueue.main + var queue: TimerQueue = .main /// Resets the Timer, clearing the number of tries and stops /// any scheduled timeout. diff --git a/Sources/Realtime/Transport.swift b/Sources/Realtime/Transport.swift index 0e44bb7..92f7364 100644 --- a/Sources/Realtime/Transport.swift +++ b/Sources/Realtime/Transport.swift @@ -20,10 +20,6 @@ import Foundation -#if canImport(FoundationNetworking) - import FoundationNetworking -#endif - // ---------------------------------------------------------------------- // MARK: - Transport Protocol @@ -69,15 +65,19 @@ public protocol Transport { public protocol TransportDelegate { /** Notified when the `Transport` opens. + + - Parameter response: Response from the server indicating that the WebSocket handshake was successful and the connection has been upgraded to webSockets */ - func onOpen() + func onOpen(response: URLResponse?) /** Notified when the `Transport` receives an error. - - Parameter error: Error from the underlying `Transport` implementation + - Parameter error: Client-side error from the underlying `Transport` implementation + - Parameter response: Response from the server, if any, that occurred with the Error + */ - func onError(error: Error) + func onError(error: Error, response: URLResponse?) /** Notified when the `Transport` receives a message from the server. @@ -90,8 +90,9 @@ public protocol TransportDelegate { Notified when the `Transport` closes. - Parameter code: Code that was sent when the `Transport` closed + - Parameter reason: A concise human-readable prose explanation for the closure */ - func onClose(code: Int) + func onClose(code: Int, reason: String?) } // ---------------------------------------------------------------------- @@ -122,20 +123,23 @@ public enum TransportReadyState { /// A `Transport` implementation that relies on URLSession's native WebSocket /// implementation. /// -/// This implementation ships default with SwiftClient however -/// SwiftClient supports earlier OS versions using one of the submodule +/// This implementation ships default with SwiftPhoenixClient however +/// SwiftPhoenixClient supports earlier OS versions using one of the submodule /// `Transport` implementations. Or you can create your own implementation using /// your own WebSocket library or implementation. @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) -public class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelegate { +open class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelegate { /// The URL to connect to internal let url: URL - /// The underling URLsession. Assigned during `connect()` - private var session: URLSession? + /// The URLSession configuration + internal let configuration: URLSessionConfiguration + + /// The underling URLSession. Assigned during `connect()` + private var session: URLSession? = nil /// The ongoing task. Assigned during `connect()` - private var task: URLSessionWebSocketTask? + private var task: URLSessionWebSocketTask? = nil /** Initializes a `Transport` layer built using URLSession's WebSocket @@ -147,9 +151,18 @@ public class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelega let transport: Transport = URLSessionTransport(url: url) ``` + Using a custom `URLSessionConfiguration` + + ```swift + let url = URL("wss://example.com/socket") + let configuration = URLSessionConfiguration.default + let transport: Transport = URLSessionTransport(url: url, configuration: configuration) + ``` + - parameter url: URL to connect to + - parameter configuration: Provide your own URLSessionConfiguration. Uses `.default` if none provided */ - init(url: URL) { + public init(url: URL, configuration: URLSessionConfiguration = .default) { // URLSession requires that the endpoint be "wss" instead of "https". let endpoint = url.absoluteString let wsEndpoint = @@ -160,6 +173,7 @@ public class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelega // Force unwrapping should be safe here since a valid URL came in and we just // replaced the protocol. self.url = URL(string: wsEndpoint)! + self.configuration = configuration super.init() } @@ -167,21 +181,21 @@ public class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelega // MARK: - Transport public var readyState: TransportReadyState = .closed - public var delegate: TransportDelegate? + public var delegate: TransportDelegate? = nil - public func connect() { - // Set the trasport state as connecting + open func connect() { + // Set the transport state as connecting readyState = .connecting // Create the session and websocket task - session = URLSession(configuration: .default, delegate: self, delegateQueue: OperationQueue()) + session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) task = session?.webSocketTask(with: url) // Start the task task?.resume() } - public func disconnect(code: Int, reason: String?) { + open func disconnect(code: Int, reason: String?) { /* TODO: 1. Provide a "strict" mode that fails if an invalid close code is given @@ -194,62 +208,66 @@ public class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelega readyState = .closing task?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) + session?.finishTasksAndInvalidate() } - public func send(data: Data) { - task?.send(.data(data)) { _ in + open func send(data: Data) { + task?.send(.string(String(data: data, encoding: .utf8)!)) { _ in // TODO: What is the behavior when an error occurs? } } // MARK: - URLSessionWebSocketDelegate - public func urlSession( + open func urlSession( _: URLSession, - webSocketTask _: URLSessionWebSocketTask, + webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol _: String? ) { // The Websocket is connected. Set Transport state to open and inform delegate readyState = .open - delegate?.onOpen() + delegate?.onOpen(response: webSocketTask.response) // Start receiving messages receive() } - public func urlSession( + open func urlSession( _: URLSession, webSocketTask _: URLSessionWebSocketTask, didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, - reason _: Data? + reason: Data? ) { // A close frame was received from the server. readyState = .closed - delegate?.onClose(code: closeCode.rawValue) + delegate?.onClose( + code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } + ) } - public func urlSession( + open func urlSession( _: URLSession, - task _: URLSessionTask, + task: URLSessionTask, didCompleteWithError error: Error? ) { // The task has terminated. Inform the delegate that the transport has closed abnormally // if this was caused by an error. guard let err = error else { return } - abnormalErrorReceived(err) + + abnormalErrorReceived(err, response: task.response) } // MARK: - Private private func receive() { - task?.receive { result in + task?.receive { [weak self] result in switch result { case let .success(message): switch message { case .data: print("Data received. This method is unsupported by the Client") case let .string(text): - self.delegate?.onMessage(message: text) + self?.delegate?.onMessage(message: text) default: fatalError("Unknown result was received. [\(result)]") } @@ -257,24 +275,26 @@ public class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelega // Since `.receive()` is only good for a single message, it must // be called again after a message is received in order to // received the next message. - self.receive() + self?.receive() case let .failure(error): print("Error when receiving \(error)") - self.abnormalErrorReceived(error) + self?.abnormalErrorReceived(error, response: nil) } } } - private func abnormalErrorReceived(_ error: Error) { + private func abnormalErrorReceived(_ error: Error, response: URLResponse?) { // Set the state of the Transport to closed readyState = .closed // Inform the Transport's delegate that an error occurred. - delegate?.onError(error: error) + delegate?.onError(error: error, response: response) // An abnormal error is results in an abnormal closure, such as internet getting dropped // so inform the delegate that the Transport has closed abnormally. This will kick off // the reconnect logic. - delegate?.onClose(code: RealtimeClient.CloseCode.abnormal.rawValue) + delegate?.onClose( + code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription + ) } } diff --git a/Tests/RealtimeTests/ChannelTopicTests.swift b/Tests/RealtimeTests/ChannelTopicTests.swift index 27b7f5d..5a21bfd 100644 --- a/Tests/RealtimeTests/ChannelTopicTests.swift +++ b/Tests/RealtimeTests/ChannelTopicTests.swift @@ -3,7 +3,6 @@ import XCTest @testable import Realtime final class ChannelTopicTests: XCTestCase { - func testRawValue() { XCTAssertEqual(ChannelTopic.all, ChannelTopic(rawValue: "realtime:*")) XCTAssertEqual(ChannelTopic.all, ChannelTopic(rawValue: "*")) @@ -13,7 +12,8 @@ final class ChannelTopicTests: XCTestCase { ) XCTAssertEqual( ChannelTopic.column("email", value: "mail@supabase.io", table: "users", schema: "public"), - ChannelTopic(rawValue: "realtime:public:users:email=eq.mail@supabase.io")) + ChannelTopic(rawValue: "realtime:public:users:email=eq.mail@supabase.io") + ) XCTAssertEqual(ChannelTopic.heartbeat, ChannelTopic(rawValue: "phoenix")) } } diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index 4a39aa9..b8a6868 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -23,10 +23,12 @@ final class RealtimeTests: XCTestCase { func testConnection() throws { try XCTSkipIf( ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, - "INTEGRATION_TESTS not defined") + "INTEGRATION_TESTS not defined" + ) let socket = RealtimeClient( - endPoint: "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey]) + "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] + ) let e = expectation(description: "testConnection") socket.onOpen { @@ -36,7 +38,7 @@ final class RealtimeTests: XCTestCase { } } - socket.onError { error in + socket.onError { error, _ in XCTFail(error.localizedDescription) } @@ -57,32 +59,34 @@ final class RealtimeTests: XCTestCase { func testChannelCreation() throws { try XCTSkipIf( ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, - "INTEGRATION_TESTS not defined") + "INTEGRATION_TESTS not defined" + ) let client = RealtimeClient( - endPoint: "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey]) + "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] + ) let allChanges = client.channel(.all) allChanges.on(.all) { message in print(message) } - allChanges.subscribe() - allChanges.unsubscribe() + allChanges.join() + allChanges.leave() allChanges.off(.all) let allPublicInsertChanges = client.channel(.schema("public")) allPublicInsertChanges.on(.insert) { message in print(message) } - allPublicInsertChanges.subscribe() - allPublicInsertChanges.unsubscribe() + allPublicInsertChanges.join() + allPublicInsertChanges.leave() allPublicInsertChanges.off(.insert) let allUsersUpdateChanges = client.channel(.table("users", schema: "public")) allUsersUpdateChanges.on(.update) { message in print(message) } - allUsersUpdateChanges.subscribe() - allUsersUpdateChanges.unsubscribe() + allUsersUpdateChanges.join() + allUsersUpdateChanges.leave() allUsersUpdateChanges.off(.update) let allUserId99Changes = client.channel( @@ -90,13 +94,13 @@ final class RealtimeTests: XCTestCase { allUserId99Changes.on(.all) { message in print(message) } - allUserId99Changes.subscribe() - allUserId99Changes.unsubscribe() + allUserId99Changes.join() + allUserId99Changes.leave() allUserId99Changes.off(.all) XCTAssertEqual(client.isConnected, false) - let e = expectation(description: self.name) + let e = expectation(description: name) client.onOpen { XCTAssertEqual(client.isConnected, true) DispatchQueue.main.asyncAfter(deadline: .now() + 1) { @@ -104,7 +108,7 @@ final class RealtimeTests: XCTestCase { } } - client.onError { error in + client.onError { error, _ in XCTFail(error.localizedDescription) }