1
+ use crate :: lua_preload;
1
2
use lazy_static:: lazy_static;
2
3
use mlua:: prelude:: LuaResult ;
3
4
use mlua:: { prelude:: * , Lua , UserData } ;
4
5
use std:: collections:: HashMap ;
6
+ use std:: sync:: Arc ;
5
7
use std:: sync:: Mutex ;
8
+ use std:: thread;
6
9
use std:: time:: Duration ;
10
+ use tokio:: runtime:: Builder ;
7
11
use tokio:: sync:: mpsc;
8
12
9
13
#[ derive( Clone , Debug ) ]
@@ -24,8 +28,8 @@ impl LuaChannel {
24
28
25
29
pub struct LuaChannelMgr {
26
30
channels : HashMap < String , LuaChannel > ,
27
- receivers : HashMap < i64 , mpsc:: Receiver < i64 > > ,
28
- senders : HashMap < i64 , mpsc:: Sender < i64 > > ,
31
+ receivers : HashMap < i64 , Arc < Mutex < mpsc:: Receiver < i64 > > > > ,
32
+ senders : HashMap < i64 , Arc < Mutex < mpsc:: Sender < i64 > > > > ,
29
33
id_counter : i64 ,
30
34
}
31
35
@@ -45,33 +49,25 @@ impl LuaChannelMgr {
45
49
self . id_counter += 1 ;
46
50
let channel = LuaChannel :: new ( name. clone ( ) , id) ;
47
51
self . channels . insert ( name. clone ( ) , channel) ;
48
- self . receivers . insert ( id, receiver) ;
49
- self . senders . insert ( id, sender) ;
52
+ self . receivers . insert ( id, Arc :: new ( Mutex :: new ( receiver) ) ) ;
53
+ self . senders . insert ( id, Arc :: new ( Mutex :: new ( sender) ) ) ;
50
54
}
51
55
52
56
pub fn get_channel ( & self , name : & str ) -> Option < LuaChannel > {
53
57
self . channels . get ( name) . cloned ( )
54
58
}
55
59
56
- pub async fn push ( & self , id : i64 , data : i64 ) -> Result < ( ) , mpsc:: error:: SendError < i64 > > {
57
- if let Some ( sender) = self . senders . get ( & id) {
58
- sender. send ( data) . await
59
- } else {
60
- Err ( mpsc:: error:: SendError ( data) )
61
- }
60
+ pub fn get_sender ( & self , id : i64 ) -> Option < Arc < Mutex < mpsc:: Sender < i64 > > > > {
61
+ self . senders . get ( & id) . cloned ( )
62
62
}
63
63
64
- pub async fn pop ( & mut self , id : i64 ) -> Option < i64 > {
65
- if let Some ( receiver) = self . receivers . get_mut ( & id) {
66
- receiver. recv ( ) . await
67
- } else {
68
- None
69
- }
64
+ pub fn get_receiver ( & self , id : i64 ) -> Option < Arc < Mutex < mpsc:: Receiver < i64 > > > > {
65
+ self . receivers . get ( & id) . cloned ( )
70
66
}
71
67
}
72
68
73
69
lazy_static ! {
74
- static ref luaChannelMgr : Mutex <LuaChannelMgr > = Mutex :: new( LuaChannelMgr :: new( ) ) ;
70
+ static ref ChannelMgr : Arc < Mutex <LuaChannelMgr >> = Arc :: new ( Mutex :: new( LuaChannelMgr :: new( ) ) ) ;
75
71
}
76
72
77
73
impl UserData for LuaChannel {
@@ -84,35 +80,47 @@ impl UserData for LuaChannel {
84
80
let id = this. id ;
85
81
let lua_seri_pack = lua. globals ( ) . get :: < LuaFunction > ( "lua_seri_pack" ) ?;
86
82
let ptr = lua_seri_pack. call :: < i64 > ( args) . unwrap ( ) ;
87
- luaChannelMgr. lock ( ) . unwrap ( ) . push ( id, ptr) . await . unwrap ( ) ;
83
+ let opt_sender = { ChannelMgr . lock ( ) . unwrap ( ) . get_sender ( id) } ;
84
+ if let Some ( sender) = opt_sender {
85
+ let sender = sender. lock ( ) . unwrap ( ) ;
86
+ sender. send ( ptr) . await . unwrap ( ) ;
87
+ }
88
88
Ok ( ( ) )
89
89
} ) ;
90
90
91
91
methods. add_async_method ( "pop" , |lua, this, ( ) | async move {
92
92
let id = this. id ;
93
- let data = luaChannelMgr. lock ( ) . unwrap ( ) . pop ( id) . await ;
94
- if let Some ( data) = data {
95
- let lua_seri_unpack = lua. globals ( ) . get :: < LuaFunction > ( "lua_seri_unpack" ) ?;
96
- let mut returns = lua_seri_unpack. call :: < mlua:: MultiValue > ( data) . unwrap ( ) ;
97
- returns. insert ( 0 , mlua:: Value :: Boolean ( true ) ) ;
98
- Ok ( returns)
99
- } else {
100
- let mut returns = mlua:: MultiValue :: new ( ) ;
101
- returns. insert ( 0 , mlua:: Value :: Boolean ( false ) ) ;
102
- Ok ( returns)
93
+ let opt_receiver = { ChannelMgr . lock ( ) . unwrap ( ) . get_receiver ( id) } ;
94
+ if let Some ( receiver) = opt_receiver {
95
+ let data = receiver. lock ( ) . unwrap ( ) . recv ( ) . await ;
96
+ if let Some ( data) = data {
97
+ let lua_seri_unpack = lua. globals ( ) . get :: < LuaFunction > ( "lua_seri_unpack" ) ?;
98
+ let mut returns = lua_seri_unpack. call :: < mlua:: MultiValue > ( data) . unwrap ( ) ;
99
+ returns. insert ( 0 , mlua:: Value :: Boolean ( true ) ) ;
100
+ return Ok ( returns) ;
101
+ }
103
102
}
103
+
104
+ let mut returns = mlua:: MultiValue :: new ( ) ;
105
+ returns. insert ( 0 , mlua:: Value :: Boolean ( false ) ) ;
106
+ Ok ( returns)
104
107
} ) ;
105
108
106
109
methods. add_async_method ( "bpop" , |lua, this, ( ) | async move {
107
110
let id = this. id ;
108
- let data = luaChannelMgr. lock ( ) . unwrap ( ) . pop ( id) . await ;
109
- if let Some ( data) = data {
110
- let lua_seri_unpack = lua. globals ( ) . get :: < LuaFunction > ( "lua_seri_unpack" ) ?;
111
- let returns = lua_seri_unpack. call :: < mlua:: MultiValue > ( data) . unwrap ( ) ;
112
- Ok ( returns)
113
- } else {
114
- Err ( mlua:: Error :: RuntimeError ( "Channel is closed" . to_string ( ) ) )
111
+ let opt_receiver = { ChannelMgr . lock ( ) . unwrap ( ) . get_receiver ( id) } ;
112
+ if let Some ( receiver) = opt_receiver {
113
+ let data = receiver. lock ( ) . unwrap ( ) . recv ( ) . await ;
114
+ if let Some ( data) = data {
115
+ let lua_seri_unpack = lua. globals ( ) . get :: < LuaFunction > ( "lua_seri_unpack" ) ?;
116
+ let mut returns = lua_seri_unpack. call :: < mlua:: MultiValue > ( data) . unwrap ( ) ;
117
+ returns. insert ( 0 , mlua:: Value :: Boolean ( true ) ) ;
118
+ return Ok ( returns) ;
119
+ }
115
120
}
121
+
122
+ let returns = mlua:: MultiValue :: new ( ) ;
123
+ Ok ( returns)
116
124
} ) ;
117
125
}
118
126
}
@@ -123,41 +131,46 @@ async fn bee_thread_sleep(_: Lua, time: u64) -> LuaResult<()> {
123
131
}
124
132
125
133
fn bee_thread_newchannel ( _: & Lua , name : String ) -> LuaResult < ( ) > {
126
- luaChannelMgr . lock ( ) . unwrap ( ) . new_channel ( name) ;
134
+ ChannelMgr . lock ( ) . unwrap ( ) . new_channel ( name) ;
127
135
Ok ( ( ) )
128
136
}
129
137
130
- fn bee_thread_channel ( _: & Lua , name : & str ) -> LuaResult < LuaChannel > {
131
- let mut mgr = luaChannelMgr . lock ( ) . unwrap ( ) ;
132
- if let Some ( channel) = mgr. get_channel ( name) {
138
+ fn bee_thread_channel ( _: & Lua , name : String ) -> LuaResult < LuaChannel > {
139
+ let mut mgr = ChannelMgr . lock ( ) . unwrap ( ) ;
140
+ if let Some ( channel) = mgr. get_channel ( & name) {
133
141
Ok ( channel)
134
142
} else {
135
143
mgr. new_channel ( name. to_string ( ) ) ;
136
- if let Some ( channel) = mgr. get_channel ( name) {
144
+ if let Some ( channel) = mgr. get_channel ( & name) {
137
145
return Ok ( channel) ;
138
146
}
139
147
Err ( mlua:: Error :: RuntimeError ( "Channel not found" . to_string ( ) ) )
140
148
}
141
149
}
142
150
143
- fn bee_thread_thread ( _: & Lua ) -> LuaResult < ( ) > {
151
+ fn bee_thread_thread ( _: & Lua , script : String ) -> LuaResult < ( ) > {
152
+ thread:: spawn ( move || {
153
+ let rt = Builder :: new_current_thread ( ) . enable_all ( ) . build ( ) . unwrap ( ) ;
154
+ rt. block_on ( async move {
155
+ let lua = unsafe { Lua :: unsafe_new ( ) } ;
156
+ if let Err ( e) = lua_preload:: lua_preload ( & lua) {
157
+ eprintln ! ( "Error during lua_preload: {:?}" , e) ;
158
+ return ;
159
+ }
160
+ lua. load ( script. as_bytes ( ) )
161
+ . call_async :: < ( ) > ( ( ) )
162
+ . await
163
+ . unwrap ( ) ;
164
+ } ) ;
165
+ } ) ;
144
166
Ok ( ( ) )
145
167
}
146
168
147
169
pub fn bee_thread ( lua : & Lua ) -> LuaResult < LuaTable > {
148
170
let thread = lua. create_table ( ) ?;
149
171
thread. set ( "sleep" , lua. create_async_function ( bee_thread_sleep) ?) ?;
150
- thread. set (
151
- "newchannel" ,
152
- lua. create_function ( |lua, name : String | Ok ( bee_thread_newchannel ( lua, name) ) ) ?,
153
- ) ?;
154
- thread. set (
155
- "channel" ,
156
- lua. create_function ( |lua, name : String | Ok ( bee_thread_channel ( lua, & name) ) ) ?,
157
- ) ?;
158
- thread. set (
159
- "thread" ,
160
- lua. create_function ( |lua, ( ) | Ok ( bee_thread_thread ( lua) ) ) ?,
161
- ) ?;
172
+ thread. set ( "newchannel" , lua. create_function ( bee_thread_newchannel) ?) ?;
173
+ thread. set ( "channel" , lua. create_function ( bee_thread_channel) ?) ?;
174
+ thread. set ( "thread" , lua. create_function ( bee_thread_thread) ?) ?;
162
175
Ok ( thread)
163
176
}
0 commit comments