@@ -4,12 +4,15 @@ use futures_util::{future::BoxFuture, stream::BoxStream};
4
4
use postgres:: { Client , NoTls } ;
5
5
use r2d2_postgres:: PostgresConnectionManager ;
6
6
use sqlx:: { postgres:: PgPoolOptions , Executor } ;
7
- use std:: { sync:: Arc , time:: Duration } ;
7
+ use std:: {
8
+ ops:: { Deref , DerefMut } ,
9
+ sync:: Arc ,
10
+ time:: Duration ,
11
+ } ;
8
12
use tokio:: runtime:: Runtime ;
9
13
use tracing:: debug;
10
14
11
15
pub type PoolClient = r2d2:: PooledConnection < PostgresConnectionManager < NoTls > > ;
12
- pub type AsyncPoolClient = sqlx:: pool:: PoolConnection < sqlx:: postgres:: Postgres > ;
13
16
14
17
const DEFAULT_SCHEMA : & str = "public" ;
15
18
@@ -20,14 +23,15 @@ pub struct Pool {
20
23
#[ cfg( not( test) ) ]
21
24
pool : r2d2:: Pool < PostgresConnectionManager < NoTls > > ,
22
25
async_pool : sqlx:: PgPool ,
26
+ runtime : Arc < Runtime > ,
23
27
metrics : Arc < InstanceMetrics > ,
24
28
max_size : u32 ,
25
29
}
26
30
27
31
impl Pool {
28
32
pub fn new (
29
33
config : & Config ,
30
- runtime : & Runtime ,
34
+ runtime : Arc < Runtime > ,
31
35
metrics : Arc < InstanceMetrics > ,
32
36
) -> Result < Pool , PoolError > {
33
37
debug ! (
@@ -39,7 +43,7 @@ impl Pool {
39
43
#[ cfg( test) ]
40
44
pub ( crate ) fn new_with_schema (
41
45
config : & Config ,
42
- runtime : & Runtime ,
46
+ runtime : Arc < Runtime > ,
43
47
metrics : Arc < InstanceMetrics > ,
44
48
schema : & str ,
45
49
) -> Result < Pool , PoolError > {
@@ -48,7 +52,7 @@ impl Pool {
48
52
49
53
fn new_inner (
50
54
config : & Config ,
51
- runtime : & Runtime ,
55
+ runtime : Arc < Runtime > ,
52
56
metrics : Arc < InstanceMetrics > ,
53
57
schema : & str ,
54
58
) -> Result < Pool , PoolError > {
@@ -109,6 +113,7 @@ impl Pool {
109
113
pool,
110
114
async_pool,
111
115
metrics,
116
+ runtime,
112
117
max_size : config. max_legacy_pool_size + config. max_pool_size ,
113
118
} )
114
119
}
@@ -139,7 +144,10 @@ impl Pool {
139
144
140
145
pub async fn get_async ( & self ) -> Result < AsyncPoolClient , PoolError > {
141
146
match self . async_pool . acquire ( ) . await {
142
- Ok ( conn) => Ok ( conn) ,
147
+ Ok ( conn) => Ok ( AsyncPoolClient {
148
+ inner : Some ( conn) ,
149
+ runtime : self . runtime . clone ( ) ,
150
+ } ) ,
143
151
Err ( err) => {
144
152
self . metrics . failed_db_connections . inc ( ) ;
145
153
Err ( PoolError :: AsyncClientError ( err) )
@@ -222,6 +230,36 @@ where
222
230
}
223
231
}
224
232
233
+ /// we wrap `sqlx::PoolConnection` so we can drop it in a sync context
234
+ /// and enter the runtime.
235
+ /// Otherwise dropping the PoolConnection will panic because it can't spawn a task.
236
+ #[ derive( Debug ) ]
237
+ pub struct AsyncPoolClient {
238
+ inner : Option < sqlx:: pool:: PoolConnection < sqlx:: postgres:: Postgres > > ,
239
+ runtime : Arc < Runtime > ,
240
+ }
241
+
242
+ impl Deref for AsyncPoolClient {
243
+ type Target = sqlx:: PgConnection ;
244
+
245
+ fn deref ( & self ) -> & Self :: Target {
246
+ self . inner . as_ref ( ) . unwrap ( )
247
+ }
248
+ }
249
+
250
+ impl DerefMut for AsyncPoolClient {
251
+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
252
+ self . inner . as_mut ( ) . unwrap ( )
253
+ }
254
+ }
255
+
256
+ impl Drop for AsyncPoolClient {
257
+ fn drop ( & mut self ) {
258
+ let _guard = self . runtime . enter ( ) ;
259
+ drop ( self . inner . take ( ) )
260
+ }
261
+ }
262
+
225
263
#[ derive( Debug ) ]
226
264
struct SetSchema {
227
265
schema : String ,
0 commit comments