Skip to content

Commit 44ed7e9

Browse files
committed
fix deadlock
1 parent 7ea2722 commit 44ed7e9

File tree

8 files changed

+195
-220
lines changed

8 files changed

+195
-220
lines changed

resources/testmain.lua

Lines changed: 21 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,23 @@
1-
local fs = require "bee.filesystem"
2-
3-
local p = fs.path("resources/testmain.lua")
4-
5-
print(p / "aaaaa")
6-
7-
local thread = require "bee.thread"
8-
thread.newchannel("hello")
9-
local hello = thread.channel("hello")
10-
print(hello)
11-
hello:push("world", "yes", "no", 1, 2, 3)
12-
13-
local a, b, c, d, e, f = hello:pop()
14-
print(a, b, c, d, e, f)
15-
16-
local time = require "bee.time"
17-
print(time.time())
18-
print(time.monotonic())
19-
20-
local windows = require "bee.windows"
21-
for k, v in pairs(windows) do
22-
print(k, v)
23-
end
24-
windows.filemode(io.stdin, 'b')
25-
26-
local socket = require "bee.socket"
27-
local select = require "bee.select"
28-
local selector = select.create()
29-
thread.sleep(100)
30-
print("sleep complete")
31-
local fd = socket.create("tcp")
32-
fd:bind("127.0.0.1", 9988)
33-
print("bind complete")
34-
fd:listen()
35-
print("listen complete")
36-
-- local cfd = fd:accept()
37-
-- print(cfd)
38-
39-
selector:event_add(fd, 1, function()
40-
print("listener fd", fd)
41-
local cfd = fd:accept()
42-
print("accept a connection", cfd)
43-
end)
44-
1+
local thread = require 'bee.thread'
2+
3+
local d = thread.channel('taskpad')
4+
thread.thread([[
5+
local thread = require 'bee.thread'
6+
print("hello world1")
7+
local taskpad = thread.channel('taskpad')
8+
print("hello world2")
9+
local counter = 0
4510
while true do
46-
for func, event in selector:wait(1000) do
47-
print("func", func, "event", event)
48-
if func then
49-
func(event)
50-
end
51-
end
11+
print("hello world")
12+
taskpad:push(counter)
13+
counter = counter + 1
14+
thread.sleep(100)
5215
end
53-
54-
16+
]])
17+
-- thread.sleep(100)
18+
print("hello world")
19+
-- coroutine.yield()
20+
print("thread 1", d, "hello world")
21+
while true do
22+
print(d:bpop())
23+
end

src/bee/lua_socket.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use super::socket::lua_socket_pool::SOCKET_POOL;
33
use mlua::prelude::LuaResult;
44
use mlua::prelude::*;
55

6-
fn bee_socket_create(_: &Lua, protocol: String) -> LuaResult<LuaSocket> {
7-
let mut socket_pool = SOCKET_POOL.lock().unwrap();
6+
async fn bee_socket_create(_: Lua, protocol: String) -> LuaResult<LuaSocket> {
7+
let mut socket_pool = SOCKET_POOL.lock().await;
88
let socket = match protocol.as_str() {
99
"tcp" => socket_pool.create_socket(SocketType::Tcp).unwrap(),
1010
"unix" => socket_pool.create_socket(SocketType::Unix).unwrap(),
@@ -16,6 +16,6 @@ fn bee_socket_create(_: &Lua, protocol: String) -> LuaResult<LuaSocket> {
1616

1717
pub fn bee_socket(lua: &Lua) -> LuaResult<LuaTable> {
1818
let table = lua.create_table()?;
19-
table.set("create", lua.create_function(bee_socket_create)?)?;
19+
table.set("create", lua.create_async_function(bee_socket_create)?)?;
2020
Ok(table)
2121
}

src/bee/lua_thread.rs

Lines changed: 66 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
use crate::lua_preload;
12
use lazy_static::lazy_static;
23
use mlua::prelude::LuaResult;
34
use mlua::{prelude::*, Lua, UserData};
45
use std::collections::HashMap;
6+
use std::sync::Arc;
57
use std::sync::Mutex;
8+
use std::thread;
69
use std::time::Duration;
10+
use tokio::runtime::Builder;
711
use tokio::sync::mpsc;
812

913
#[derive(Clone, Debug)]
@@ -24,8 +28,8 @@ impl LuaChannel {
2428

2529
pub struct LuaChannelMgr {
2630
channels: HashMap<String, LuaChannel>,
27-
receivers: HashMap<i64, mpsc::Receiver<i64>>,
28-
senders: HashMap<i64, mpsc::Sender<i64>>,
31+
receivers: HashMap<i64, Arc<Mutex<mpsc::Receiver<i64>>>>,
32+
senders: HashMap<i64, Arc<Mutex<mpsc::Sender<i64>>>>,
2933
id_counter: i64,
3034
}
3135

@@ -45,33 +49,25 @@ impl LuaChannelMgr {
4549
self.id_counter += 1;
4650
let channel = LuaChannel::new(name.clone(), id);
4751
self.channels.insert(name.clone(), channel);
48-
self.receivers.insert(id, receiver);
49-
self.senders.insert(id, sender);
52+
self.receivers.insert(id, Arc::new(Mutex::new(receiver)));
53+
self.senders.insert(id, Arc::new(Mutex::new(sender)));
5054
}
5155

5256
pub fn get_channel(&self, name: &str) -> Option<LuaChannel> {
5357
self.channels.get(name).cloned()
5458
}
5559

56-
pub async fn push(&self, id: i64, data: i64) -> Result<(), mpsc::error::SendError<i64>> {
57-
if let Some(sender) = self.senders.get(&id) {
58-
sender.send(data).await
59-
} else {
60-
Err(mpsc::error::SendError(data))
61-
}
60+
pub fn get_sender(&self, id: i64) -> Option<Arc<Mutex<mpsc::Sender<i64>>>> {
61+
self.senders.get(&id).cloned()
6262
}
6363

64-
pub async fn pop(&mut self, id: i64) -> Option<i64> {
65-
if let Some(receiver) = self.receivers.get_mut(&id) {
66-
receiver.recv().await
67-
} else {
68-
None
69-
}
64+
pub fn get_receiver(&self, id: i64) -> Option<Arc<Mutex<mpsc::Receiver<i64>>>> {
65+
self.receivers.get(&id).cloned()
7066
}
7167
}
7268

7369
lazy_static! {
74-
static ref luaChannelMgr: Mutex<LuaChannelMgr> = Mutex::new(LuaChannelMgr::new());
70+
static ref ChannelMgr: Arc<Mutex<LuaChannelMgr>> = Arc::new(Mutex::new(LuaChannelMgr::new()));
7571
}
7672

7773
impl UserData for LuaChannel {
@@ -84,35 +80,47 @@ impl UserData for LuaChannel {
8480
let id = this.id;
8581
let lua_seri_pack = lua.globals().get::<LuaFunction>("lua_seri_pack")?;
8682
let ptr = lua_seri_pack.call::<i64>(args).unwrap();
87-
luaChannelMgr.lock().unwrap().push(id, ptr).await.unwrap();
83+
let opt_sender = { ChannelMgr.lock().unwrap().get_sender(id) };
84+
if let Some(sender) = opt_sender {
85+
let sender = sender.lock().unwrap();
86+
sender.send(ptr).await.unwrap();
87+
}
8888
Ok(())
8989
});
9090

9191
methods.add_async_method("pop", |lua, this, ()| async move {
9292
let id = this.id;
93-
let data = luaChannelMgr.lock().unwrap().pop(id).await;
94-
if let Some(data) = data {
95-
let lua_seri_unpack = lua.globals().get::<LuaFunction>("lua_seri_unpack")?;
96-
let mut returns = lua_seri_unpack.call::<mlua::MultiValue>(data).unwrap();
97-
returns.insert(0, mlua::Value::Boolean(true));
98-
Ok(returns)
99-
} else {
100-
let mut returns = mlua::MultiValue::new();
101-
returns.insert(0, mlua::Value::Boolean(false));
102-
Ok(returns)
93+
let opt_receiver = { ChannelMgr.lock().unwrap().get_receiver(id) };
94+
if let Some(receiver) = opt_receiver {
95+
let data = receiver.lock().unwrap().recv().await;
96+
if let Some(data) = data {
97+
let lua_seri_unpack = lua.globals().get::<LuaFunction>("lua_seri_unpack")?;
98+
let mut returns = lua_seri_unpack.call::<mlua::MultiValue>(data).unwrap();
99+
returns.insert(0, mlua::Value::Boolean(true));
100+
return Ok(returns);
101+
}
103102
}
103+
104+
let mut returns = mlua::MultiValue::new();
105+
returns.insert(0, mlua::Value::Boolean(false));
106+
Ok(returns)
104107
});
105108

106109
methods.add_async_method("bpop", |lua, this, ()| async move {
107110
let id = this.id;
108-
let data = luaChannelMgr.lock().unwrap().pop(id).await;
109-
if let Some(data) = data {
110-
let lua_seri_unpack = lua.globals().get::<LuaFunction>("lua_seri_unpack")?;
111-
let returns = lua_seri_unpack.call::<mlua::MultiValue>(data).unwrap();
112-
Ok(returns)
113-
} else {
114-
Err(mlua::Error::RuntimeError("Channel is closed".to_string()))
111+
let opt_receiver = { ChannelMgr.lock().unwrap().get_receiver(id) };
112+
if let Some(receiver) = opt_receiver {
113+
let data = receiver.lock().unwrap().recv().await;
114+
if let Some(data) = data {
115+
let lua_seri_unpack = lua.globals().get::<LuaFunction>("lua_seri_unpack")?;
116+
let mut returns = lua_seri_unpack.call::<mlua::MultiValue>(data).unwrap();
117+
returns.insert(0, mlua::Value::Boolean(true));
118+
return Ok(returns);
119+
}
115120
}
121+
122+
let returns = mlua::MultiValue::new();
123+
Ok(returns)
116124
});
117125
}
118126
}
@@ -123,41 +131,46 @@ async fn bee_thread_sleep(_: Lua, time: u64) -> LuaResult<()> {
123131
}
124132

125133
fn bee_thread_newchannel(_: &Lua, name: String) -> LuaResult<()> {
126-
luaChannelMgr.lock().unwrap().new_channel(name);
134+
ChannelMgr.lock().unwrap().new_channel(name);
127135
Ok(())
128136
}
129137

130-
fn bee_thread_channel(_: &Lua, name: &str) -> LuaResult<LuaChannel> {
131-
let mut mgr = luaChannelMgr.lock().unwrap();
132-
if let Some(channel) = mgr.get_channel(name) {
138+
fn bee_thread_channel(_: &Lua, name: String) -> LuaResult<LuaChannel> {
139+
let mut mgr = ChannelMgr.lock().unwrap();
140+
if let Some(channel) = mgr.get_channel(&name) {
133141
Ok(channel)
134142
} else {
135143
mgr.new_channel(name.to_string());
136-
if let Some(channel) = mgr.get_channel(name) {
144+
if let Some(channel) = mgr.get_channel(&name) {
137145
return Ok(channel);
138146
}
139147
Err(mlua::Error::RuntimeError("Channel not found".to_string()))
140148
}
141149
}
142150

143-
fn bee_thread_thread(_: &Lua) -> LuaResult<()> {
151+
fn bee_thread_thread(_: &Lua, script: String) -> LuaResult<()> {
152+
thread::spawn(move || {
153+
let rt = Builder::new_current_thread().enable_all().build().unwrap();
154+
rt.block_on(async move {
155+
let lua = unsafe { Lua::unsafe_new() };
156+
if let Err(e) = lua_preload::lua_preload(&lua) {
157+
eprintln!("Error during lua_preload: {:?}", e);
158+
return;
159+
}
160+
lua.load(script.as_bytes())
161+
.call_async::<()>(())
162+
.await
163+
.unwrap();
164+
});
165+
});
144166
Ok(())
145167
}
146168

147169
pub fn bee_thread(lua: &Lua) -> LuaResult<LuaTable> {
148170
let thread = lua.create_table()?;
149171
thread.set("sleep", lua.create_async_function(bee_thread_sleep)?)?;
150-
thread.set(
151-
"newchannel",
152-
lua.create_function(|lua, name: String| Ok(bee_thread_newchannel(lua, name)))?,
153-
)?;
154-
thread.set(
155-
"channel",
156-
lua.create_function(|lua, name: String| Ok(bee_thread_channel(lua, &name)))?,
157-
)?;
158-
thread.set(
159-
"thread",
160-
lua.create_function(|lua, ()| Ok(bee_thread_thread(lua)))?,
161-
)?;
172+
thread.set("newchannel", lua.create_function(bee_thread_newchannel)?)?;
173+
thread.set("channel", lua.create_function(bee_thread_channel)?)?;
174+
thread.set("thread", lua.create_function(bee_thread_thread)?)?;
162175
Ok(thread)
163176
}

src/bee/socket/lua_select.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ impl LuaSelect {
3131
.map(|(socket_id, (callback, flag))| (*socket_id, callback.clone(), *flag))
3232
.collect();
3333
{
34-
let mut socket_pool = SOCKET_POOL.lock().unwrap();
34+
let mut socket_pool = SOCKET_POOL.lock().await;
3535
for (socket_id, callback, flag) in callbacks {
3636
if flag & 0x01 != 0 {
3737
select! {

0 commit comments

Comments
 (0)