1
- use core:: fmt;
2
-
3
1
use crate :: {
4
2
pubsub:: { In , JsonSink , Listener , Out } ,
5
3
types:: InboundData ,
4
+ HandlerCtx , TaskSet ,
6
5
} ;
6
+ use core:: fmt;
7
7
use serde_json:: value:: RawValue ;
8
- use tokio:: {
9
- select,
10
- sync:: { mpsc, oneshot, watch} ,
11
- task:: JoinHandle ,
12
- } ;
8
+ use tokio:: { pin, select, sync:: mpsc, task:: JoinHandle } ;
13
9
use tokio_stream:: StreamExt ;
10
+ use tokio_util:: sync:: WaitForCancellationFutureOwned ;
14
11
use tracing:: { debug, debug_span, error, instrument, trace, Instrument } ;
15
12
16
13
/// Default notification buffer size per task.
@@ -19,18 +16,6 @@ pub const DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT: usize = 16;
19
16
/// Type alias for identifying connections.
20
17
pub type ConnectionId = u64 ;
21
18
22
- /// Holds the shutdown signal for some server.
23
- #[ derive( Debug ) ]
24
- pub struct ServerShutdown {
25
- pub ( crate ) _shutdown : watch:: Sender < ( ) > ,
26
- }
27
-
28
- impl From < watch:: Sender < ( ) > > for ServerShutdown {
29
- fn from ( sender : watch:: Sender < ( ) > ) -> Self {
30
- Self { _shutdown : sender }
31
- }
32
- }
33
-
34
19
/// The `ListenerTask` listens for new connections, and spawns `RouteTask`s for
35
20
/// each.
36
21
pub ( crate ) struct ListenerTask < T : Listener > {
@@ -67,16 +52,17 @@ where
67
52
}
68
53
69
54
/// Spawn the future produced by [`Self::task_future`].
70
- pub ( crate ) fn spawn ( self ) -> JoinHandle < ( ) > {
55
+ pub ( crate ) fn spawn ( self ) -> JoinHandle < Option < ( ) > > {
56
+ let tasks = self . manager . root_tasks . clone ( ) ;
71
57
let future = self . task_future ( ) ;
72
- tokio :: spawn ( future)
58
+ tasks . spawn_cancellable ( future)
73
59
}
74
60
}
75
61
76
62
/// The `ConnectionManager` provides connections with IDs, and handles spawning
77
63
/// the [`RouteTask`] for each connection.
78
64
pub ( crate ) struct ConnectionManager {
79
- pub ( crate ) shutdown : watch :: Receiver < ( ) > ,
65
+ pub ( crate ) root_tasks : TaskSet ,
80
66
81
67
pub ( crate ) next_id : ConnectionId ,
82
68
@@ -107,19 +93,18 @@ impl ConnectionManager {
107
93
) -> ( RouteTask < T > , WriteTask < T > ) {
108
94
let ( tx, rx) = mpsc:: channel ( self . notification_buffer_per_task ) ;
109
95
110
- let ( gone_tx , gone_rx ) = oneshot :: channel ( ) ;
96
+ let tasks = self . root_tasks . child ( ) ;
111
97
112
98
let rt = RouteTask {
113
99
router : self . router ( ) ,
114
100
conn_id,
115
101
write_task : tx,
116
102
requests,
117
- gone : gone_tx ,
103
+ tasks : tasks . clone ( ) ,
118
104
} ;
119
105
120
106
let wt = WriteTask {
121
- shutdown : self . shutdown . clone ( ) ,
122
- gone : gone_rx,
107
+ tasks,
123
108
conn_id,
124
109
json : rx,
125
110
connection,
@@ -156,8 +141,8 @@ struct RouteTask<T: crate::pubsub::Listener> {
156
141
pub ( crate ) write_task : mpsc:: Sender < Box < RawValue > > ,
157
142
/// Stream of requests.
158
143
pub ( crate ) requests : In < T > ,
159
- /// Sender to the [`WriteTask`], to notify it that this task is done.
160
- pub ( crate ) gone : oneshot :: Sender < ( ) > ,
144
+ /// The task set for this connection
145
+ pub ( crate ) tasks : TaskSet ,
161
146
}
162
147
163
148
impl < T : crate :: pubsub:: Listener > fmt:: Debug for RouteTask < T > {
@@ -179,18 +164,27 @@ where
179
164
/// to handle the request, and given a sender to the [`WriteTask`]. This
180
165
/// ensures that requests can be handled concurrently.
181
166
#[ instrument( name = "RouteTask" , skip( self ) , fields( conn_id = self . conn_id) ) ]
182
- pub async fn task_future ( self ) {
167
+ pub async fn task_future ( self , cancel : WaitForCancellationFutureOwned ) {
183
168
let RouteTask {
184
169
router,
185
170
mut requests,
186
171
write_task,
187
- gone ,
172
+ tasks ,
188
173
..
189
174
} = self ;
190
175
176
+ // The write task is responsible for waiting for its children
177
+ let children = tasks. child ( ) ;
178
+
179
+ pin ! ( cancel) ;
180
+
191
181
loop {
192
182
select ! {
193
183
biased;
184
+ _ = & mut cancel => {
185
+ debug!( "RouteTask cancelled" ) ;
186
+ break ;
187
+ }
194
188
_ = write_task. closed( ) => {
195
189
debug!( "WriteTask has gone away" ) ;
196
190
break ;
@@ -208,7 +202,11 @@ where
208
202
209
203
let span = debug_span!( "pubsub request handling" , reqs = reqs. len( ) ) ;
210
204
211
- let ctx = write_task. clone( ) . into( ) ;
205
+ let ctx =
206
+ HandlerCtx :: new(
207
+ Some ( write_task. clone( ) ) ,
208
+ children. clone( ) ,
209
+ ) ;
212
210
213
211
let fut = router. handle_request_batch( ctx, reqs) ;
214
212
let write_task = write_task. clone( ) ;
@@ -223,7 +221,7 @@ where
223
221
} ;
224
222
225
223
// Run the future in a new task.
226
- tokio :: spawn (
224
+ children . spawn_cancellable (
227
225
async move {
228
226
// Send the response to the write task.
229
227
// we don't care if the receiver has gone away,
@@ -239,27 +237,23 @@ where
239
237
}
240
238
}
241
239
}
242
- // No funny business. Drop the gone signal.
243
- drop ( gone) ;
240
+ children. shutdown ( ) . await ;
244
241
}
245
242
246
243
/// Spawn the future produced by [`Self::task_future`].
247
244
pub ( crate ) fn spawn ( self ) -> tokio:: task:: JoinHandle < ( ) > {
248
- let future = self . task_future ( ) ;
249
- tokio:: spawn ( future)
245
+ let tasks = self . tasks . clone ( ) ;
246
+
247
+ let future = move |cancel| self . task_future ( cancel) ;
248
+
249
+ tasks. spawn_graceful ( future)
250
250
}
251
251
}
252
252
253
253
/// The Write Task is responsible for writing JSON to the outbound connection.
254
254
struct WriteTask < T : Listener > {
255
- /// Shutdown signal.
256
- ///
257
- /// Shutdowns bubble back up to [`RouteTask`] when the write task is
258
- /// dropped, via the closed `json` channel.
259
- pub ( crate ) shutdown : watch:: Receiver < ( ) > ,
260
-
261
- /// Signal that the connection has gone away.
262
- pub ( crate ) gone : oneshot:: Receiver < ( ) > ,
255
+ /// Task set
256
+ pub ( crate ) tasks : TaskSet ,
263
257
264
258
/// ID of the connection.
265
259
pub ( crate ) conn_id : ConnectionId ,
@@ -281,25 +275,23 @@ impl<T: Listener> WriteTask<T> {
281
275
/// channel, and acts on them. It handles JSON messages, and going away
282
276
/// instructions. It also listens for the global shutdown signal from the
283
277
/// [`ServerShutdown`] struct.
278
+ ///
279
+ /// [`ServerShutdown`]: crate::pubsub::ServerShutdown
284
280
#[ instrument( skip( self ) , fields( conn_id = self . conn_id) ) ]
285
281
pub ( crate ) async fn task_future ( self ) {
286
282
let WriteTask {
287
- mut shutdown,
288
- mut gone,
283
+ tasks,
289
284
mut json,
290
285
mut connection,
291
286
..
292
287
} = self ;
293
- shutdown . mark_unchanged ( ) ;
288
+
294
289
loop {
295
290
select ! {
296
291
biased;
297
- _ = & mut gone => {
298
- debug!( "Connection has gone away" ) ;
299
- break ;
300
- }
301
- _ = shutdown. changed( ) => {
302
- debug!( "shutdown signal received" ) ;
292
+
293
+ _ = tasks. cancelled( ) => {
294
+ debug!( "Shutdown signal received" ) ;
303
295
break ;
304
296
}
305
297
json = json. recv( ) => {
@@ -317,7 +309,9 @@ impl<T: Listener> WriteTask<T> {
317
309
}
318
310
319
311
/// Spawn the future produced by [`Self::task_future`].
320
- pub ( crate ) fn spawn ( self ) -> JoinHandle < ( ) > {
321
- tokio:: spawn ( self . task_future ( ) )
312
+ pub ( crate ) fn spawn ( self ) -> tokio:: task:: JoinHandle < Option < ( ) > > {
313
+ let tasks = self . tasks . clone ( ) ;
314
+ let future = self . task_future ( ) ;
315
+ tasks. spawn_cancellable ( future)
322
316
}
323
317
}
0 commit comments