Commit 5bdcf48

Merge remote-tracking branch 'origin/main' into dev-pooltool

2 parents: ec4d6df + 540bdcb

160 files changed: 13149 additions, 987 deletions


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -1421,7 +1421,7 @@ log*
 default*
 events.*

-# DI-engine special key
+# LightZero special key
 *default_logger.txt
 *default_tb_logger
 *evaluate.txt
@@ -1448,3 +1448,4 @@ events.*

 # pooltool-specific stuff
 !/assets/pooltool/**
+lzero/mcts/ctree/ctree_alphazero/pybind11

README.md

Lines changed: 19 additions & 18 deletions
@@ -122,24 +122,25 @@ LightZero is a library with a [PyTorch](https://pytorch.org/) implementation of

 The environments and algorithms currently supported by LightZero are shown in the table below:

-| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero |
-|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|
-| TicTacToe     | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
-| Gomoku        | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
-| Connect4      | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
-| 2048          | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ |
-| Chess         | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
-| Go            | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
-| CartPole      | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| Pendulum      | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| LunarLander   | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |
-| Atari         | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| MuJoCo        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| MiniGrid      | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| Bsuite        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| Memory        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
+
+| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | ReZero |
+|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|----------------|
+| TicTacToe     | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 |
+| Gomoku        | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | ✔ |
+| Connect4      | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | ✔ |
+| 2048          | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ | ✔ | 🔒 |
+| Chess         | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
+| Go            | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
+| CartPole      | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
+| Pendulum      | --- | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 |
+| LunarLander   | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 |
+| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 |
+| Atari         | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
+| MuJoCo        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
+| MiniGrid      | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| Bsuite        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| Memory        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |


 <sup>(1): "✔" means that the corresponding item is finished and well-tested.</sup>

README.zh.md

Lines changed: 18 additions & 19 deletions
@@ -110,25 +110,24 @@ LightZero is an MCTS algorithm library implemented in [PyTorch](https://pytorch.org/),

 The environments and algorithms currently supported by LightZero are shown in the table below:

-| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero |
-|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|
-| TicTacToe     | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
-| Gomoku        | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
-| Connect4      | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
-| 2048          | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ |
-| Chess         | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
-| Go            | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
-| CartPole      | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| Pendulum      | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| LunarLander   | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |
-| Atari         | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
-| MuJoCo        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| MiniGrid      | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| Bsuite        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| Memory        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
-| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 |
-
+| Env./Algo.    | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero | UniZero | ReZero |
+|---------------| -------- | ------ |-------------| ------------------ | ---------- |----------------|---------------|----------------|
+| TicTacToe     | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | 🔒 |
+| Gomoku        | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 | ✔ | ✔ |
+| Connect4      | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 | ✔ | ✔ |
+| 2048          | --- | ✔ | 🔒 | 🔒 | 🔒 | ✔ | ✔ | 🔒 |
+| Chess         | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
+| Go            | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
+| CartPole      | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
+| Pendulum      | --- | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 |
+| LunarLander   | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | 🔒 |
+| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 |
+| Atari         | --- | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ | ✔ |
+| MuJoCo        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
+| MiniGrid      | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| Bsuite        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| Memory        | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| SumToThree (billiards) | --- | 🔒 | 🔒 | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |

 <sup>(1): "✔" indicates that the corresponding item is finished and well-tested.</sup>

lzero/agent/alphazero.py

Lines changed: 2 additions & 2 deletions
@@ -198,9 +198,9 @@ def train(
 new_data = sum(new_data, [])

 if self.cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = len(new_data)
-    update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
 replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

 # Learn policy from collected data
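
The only change here, repeated in the agent and entry files below, is the rename of the config key `model_update_ratio` to `replay_ratio` in the fallback that derives `update_per_collect` from the amount of freshly collected data. A minimal standalone sketch of that fallback, assuming only that the policy config exposes `update_per_collect` and `replay_ratio` attributes (the helper function is illustrative, not part of LightZero's API):

```python
from types import SimpleNamespace

def resolve_update_per_collect(policy_cfg, collected_transitions_num: int) -> int:
    """Return how many gradient updates to run after one collection phase."""
    if policy_cfg.update_per_collect is not None:
        # Explicit setting wins: use it unchanged.
        return policy_cfg.update_per_collect
    # Fallback: derive the update count from the replay ratio,
    # i.e. gradient updates per newly collected transition.
    return int(collected_transitions_num * policy_cfg.replay_ratio)

# Example: 400 newly collected transitions with replay_ratio=0.25 -> 100 updates.
cfg = SimpleNamespace(update_per_collect=None, replay_ratio=0.25)
assert resolve_update_per_collect(cfg, 400) == 100
```

Deriving the update count this way keeps the replay ratio (updates per collected transition) constant even when the per-iteration collection size varies.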

lzero/agent/efficientzero.py

Lines changed: 2 additions & 2 deletions
@@ -228,9 +228,9 @@ def train(
 # Collect data by default config n_sample/n_episode.
 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 if self.cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
-    update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
 # save returned new_data collected by the collector
 replay_buffer.push_game_segments(new_data)
 # remove the oldest data if the replay buffer is full.

lzero/agent/gumbel_muzero.py

Lines changed: 2 additions & 2 deletions
@@ -228,9 +228,9 @@ def train(
 # Collect data by default config n_sample/n_episode.
 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 if self.cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
-    update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
 # save returned new_data collected by the collector
 replay_buffer.push_game_segments(new_data)
 # remove the oldest data if the replay buffer is full.

lzero/agent/muzero.py

Lines changed: 2 additions & 2 deletions
@@ -228,9 +228,9 @@ def train(
 # Collect data by default config n_sample/n_episode.
 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 if self.cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
-    update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
 # save returned new_data collected by the collector
 replay_buffer.push_game_segments(new_data)
 # remove the oldest data if the replay buffer is full.

lzero/agent/sampled_alphazero.py

Lines changed: 2 additions & 2 deletions
@@ -198,9 +198,9 @@ def train(
 new_data = sum(new_data, [])

 if self.cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = len(new_data)
-    update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
 replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

 # Learn policy from collected data

lzero/agent/sampled_efficientzero.py

Lines changed: 2 additions & 2 deletions
@@ -228,9 +228,9 @@ def train(
 # Collect data by default config n_sample/n_episode.
 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 if self.cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
-    update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * self.cfg.policy.replay_ratio)
 # save returned new_data collected by the collector
 replay_buffer.push_game_segments(new_data)
 # remove the oldest data if the replay buffer is full.

lzero/entry/__init__.py

Lines changed: 7 additions & 4 deletions
@@ -1,7 +1,10 @@
-from .train_alphazero import train_alphazero
 from .eval_alphazero import eval_alphazero
-from .train_muzero import train_muzero
-from .train_muzero_with_reward_model import train_muzero_with_reward_model
 from .eval_muzero import eval_muzero
 from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
-from .train_muzero_with_gym_env import train_muzero_with_gym_env
+from .train_alphazero import train_alphazero
+from .train_muzero import train_muzero
+from .train_muzero_with_gym_env import train_muzero_with_gym_env
+from .train_muzero_with_gym_env import train_muzero_with_gym_env
+from .train_muzero_with_reward_model import train_muzero_with_reward_model
+from .train_rezero import train_rezero
+from .train_unizero import train_unizero
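
Besides alphabetizing the re-exports (and, incidentally, importing `train_muzero_with_gym_env` twice, at new lines 6 and 7), this block adds two new entry points, `train_rezero` and `train_unizero`. A hypothetical usage sketch, assuming they follow the same `[user_config, create_cfg]` plus `seed`/`max_env_step` calling convention as the existing `train_muzero` entry; the zoo config path in the comments is a placeholder, not something introduced by this commit:

```python
from lzero.entry import train_rezero, train_unizero  # new exports in this commit

# A config pair would normally come from an algorithm-specific module under zoo/,
# e.g. (placeholder path, assumed layout):
# from zoo.classic_control.cartpole.config.cartpole_unizero_config import main_config, create_config

# Assumed to mirror train_muzero's convention: a [user_config, create_cfg] pair,
# a seed, and a stopping budget such as max_env_step.
# train_unizero([main_config, create_config], seed=0, max_env_step=int(1e5))
# train_rezero([main_config, create_config], seed=0, max_env_step=int(1e5))
```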

lzero/entry/eval_muzero.py

Lines changed: 4 additions & 3 deletions
@@ -13,6 +13,7 @@
 from ding.utils import set_pkg_seed
 from ding.worker import BaseLearner
 from lzero.worker import MuZeroEvaluator
+from lzero.entry.utils import initialize_zeros_batch


 def eval_muzero(
@@ -25,7 +26,7 @@ def eval_muzero(
 ) -> 'Policy':  # noqa
     """
     Overview:
-        The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
+        The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, StochasticMuZero, GumbelMuZero, UniZero, etc.
     Arguments:
         - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
             ``Tuple[dict, dict]`` type means [user_config, create_cfg].
@@ -38,8 +39,8 @@ def eval_muzero(
         - policy (:obj:`Policy`): Converged policy.
     """
     cfg, create_cfg = input_cfg
-    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
-        "LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"
+    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero', 'unizero'], \
+        "LightZero now only support the following algo.: 'efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero', 'unizero'"

     if cfg.policy.cuda and torch.cuda.is_available():
         cfg.policy.device = 'cuda'

lzero/entry/train_alphazero.py

Lines changed: 2 additions & 2 deletions
@@ -119,9 +119,9 @@ def train_alphazero(
 new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
 new_data = sum(new_data, [])
 if cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
     collected_transitions_num = len(new_data)
-    update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
 replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

 # Learn policy from collected data
