[feature] Add small CNN for grids 5x5 and up #4434

Merged
merged 5 commits into from
Sep 2, 2020
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -22,6 +22,8 @@ and this project adheres to
   Note that PyTorch 1.6.0 or greater should be installed to use this feature; see
   [the PyTorch website](https://pytorch.org/) for installation instructions. (#4335)
 - The minimum supported version of TensorFlow was increased to 1.14.0. (#4411)
+- A CNN (`vis_encode_type: match3`) for smaller grids, e.g. board games, has been added.
+  (#4434)

 ### Bug Fixes
 #### com.unity.ml-agents (C#)
2 changes: 1 addition & 1 deletion docs/Training-Configuration-File.md
@@ -41,7 +41,7 @@ choice of the trainer (which we review on subsequent sections).
| `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Correspond to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger. <br><br> Typical range: `32` - `512` |
| `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems. <br><br> Typical range: `1` - `3` |
| `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
-| `network_settings -> vis_encoder_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. |
+| `network_settings -> vis_encoder_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsson et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
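
The 5x5 lower bound for `match3` follows from the encoder's shape arithmetic: each of its two 3x3, stride-1 convolutions (unpadded, as in the `SmallVisualEncoder` added in this PR) trims 2 from each side. A minimal sketch in plain Python; `conv_out` and `match3_spatial_size` are illustrative names, not part of ML-Agents:

```python
def conv_out(side: int, kernel: int = 3, stride: int = 1) -> int:
    """Output side length of an unpadded convolution (standard formula)."""
    return (side - kernel) // stride + 1


def match3_spatial_size(side: int) -> int:
    """Side length left after two unpadded 3x3, stride-1 convolutions."""
    return conv_out(conv_out(side))


# A 5x5 board collapses to a single spatial position; anything smaller
# leaves no positions at all, which is why 5 is the minimum side.
for side in (4, 5, 8):
    print(side, "->", match3_spatial_size(side))
```

Under these assumptions a 4x4 input yields a side of 0, a 5x5 input yields 1, and an 8x8 input yields 4.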


## Trainer-specific Configurations
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -59,6 +59,7 @@ def as_dict(self):


 class EncoderType(Enum):
+    MATCH3 = "match3"
     SIMPLE = "simple"
     NATURE_CNN = "nature_cnn"
     RESNET = "resnet"
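
The tests in this PR build the setting with `EncoderType(vis_encode_type)`, which relies on Python's lookup-by-value for `Enum`. A self-contained sketch of that behavior; the enum is re-declared here purely for illustration, mirroring the values in the diff:

```python
from enum import Enum


class EncoderType(Enum):
    # Re-declared for illustration; values mirror settings.py in this PR.
    MATCH3 = "match3"
    SIMPLE = "simple"
    NATURE_CNN = "nature_cnn"
    RESNET = "resnet"


# Lookup by value returns the matching member.
encoder = EncoderType("match3")
print(encoder is EncoderType.MATCH3)

# An unknown string raises ValueError instead of silently falling back.
try:
    EncoderType("match4")
except ValueError as err:
    print("rejected:", err)
```

A misspelled encoder name in a configuration would therefore fail at the point where the string is converted, assuming the conversion goes through the enum constructor as it does in the tests.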
8 changes: 4 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_simple_rl.py
@@ -192,15 +192,15 @@ def test_visual_ppo(num_visual, use_discrete):


 @pytest.mark.parametrize("num_visual", [1, 2])
-@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
+@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
 def test_visual_advanced_ppo(vis_encode_type, num_visual):
     env = SimpleEnvironment(
         [BRAIN_NAME],
         use_discrete=True,
         num_visual=num_visual,
         num_vector=0,
         step_size=0.5,
-        vis_obs_size=(36, 36, 3),
+        vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
     )
     new_networksettings = attr.evolve(
         SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
@@ -271,15 +271,15 @@ def test_visual_sac(num_visual, use_discrete):


 @pytest.mark.parametrize("num_visual", [1, 2])
-@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
+@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
 def test_visual_advanced_sac(vis_encode_type, num_visual):
     env = SimpleEnvironment(
         [BRAIN_NAME],
         use_discrete=True,
         num_visual=num_visual,
         num_vector=0,
         step_size=0.5,
-        vis_obs_size=(36, 36, 3),
+        vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
     )
     new_networksettings = attr.evolve(
         SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
8 changes: 4 additions & 4 deletions ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
@@ -193,15 +193,15 @@ def test_visual_ppo(num_visual, use_discrete):


 @pytest.mark.parametrize("num_visual", [1, 2])
-@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
+@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
 def test_visual_advanced_ppo(vis_encode_type, num_visual):
     env = SimpleEnvironment(
         [BRAIN_NAME],
         use_discrete=True,
         num_visual=num_visual,
         num_vector=0,
         step_size=0.5,
-        vis_obs_size=(36, 36, 3),
+        vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
     )
     new_networksettings = attr.evolve(
         SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
@@ -272,15 +272,15 @@ def test_visual_sac(num_visual, use_discrete):


 @pytest.mark.parametrize("num_visual", [1, 2])
-@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
+@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
 def test_visual_advanced_sac(vis_encode_type, num_visual):
     env = SimpleEnvironment(
         [BRAIN_NAME],
         use_discrete=True,
         num_visual=num_visual,
         num_vector=0,
         step_size=0.5,
-        vis_obs_size=(36, 36, 3),
+        vis_obs_size=(5, 5, 5) if vis_encode_type == "match3" else (36, 36, 3),
     )
     new_networksettings = attr.evolve(
         SAC_CONFIG.network_settings, vis_encode_type=EncoderType(vis_encode_type)
54 changes: 53 additions & 1 deletion ml-agents/mlagents/trainers/tf/models.py
@@ -32,6 +32,7 @@ class ModelUtils:
     # Minimum supported side for each encoder type. If refactoring an encoder, please
     # adjust these also.
     MIN_RESOLUTION_FOR_ENCODER = {
+        EncoderType.MATCH3: 5,
         EncoderType.SIMPLE: 20,
         EncoderType.NATURE_CNN: 36,
         EncoderType.RESNET: 15,
@@ -211,7 +212,10 @@ def create_normalizer(vector_obs: tf.Tensor) -> NormalizerTensors:
             dtype=tf.float32,
             initializer=tf.ones_initializer(),
         )
-        initialize_normalization, update_normalization = ModelUtils.create_normalizer_update(
+        (
+            initialize_normalization,
+            update_normalization,
+        ) = ModelUtils.create_normalizer_update(
             vector_obs, steps, running_mean, running_variance
         )
         return NormalizerTensors(
@@ -346,6 +350,53 @@ def create_visual_observation_encoder(
             )
         return hidden_flat

+    @staticmethod
+    def create_match3_visual_observation_encoder(
+        image_input: tf.Tensor,
+        h_size: int,
+        activation: ActivationFunction,
+        num_layers: int,
+        scope: str,
+        reuse: bool,
+    ) -> tf.Tensor:
+        """
+        Builds a CNN with the architecture used by King for Candy Crush. Optimized
+        for grid-shaped boards, such as with Match-3 games.
+        :param image_input: The placeholder for the image input to use.
+        :param h_size: Hidden layer size.
+        :param activation: What type of activation function to use for layers.
+        :param num_layers: number of hidden layers to create.
+        :param scope: The scope of the graph within which to create the ops.
+        :param reuse: Whether to re-use the weights within the same scope.
+        :return: List of hidden layer tensors.
+        """
+        with tf.variable_scope(scope):
+            conv1 = tf.layers.conv2d(
+                image_input,
+                35,
+                kernel_size=[3, 3],
+                strides=[1, 1],
+                activation=tf.nn.elu,
+                reuse=reuse,
+                name="conv_1",
+            )
+            conv2 = tf.layers.conv2d(
+                conv1,
+                144,
+                kernel_size=[3, 3],
+                strides=[1, 1],
+                activation=tf.nn.elu,
+                reuse=reuse,
+                name="conv_2",
+            )
+            hidden = tf.layers.flatten(conv2)
+
+        with tf.variable_scope(scope + "/" + "flat_encoding"):
+            hidden_flat = ModelUtils.create_vector_observation_encoder(
+                hidden, h_size, activation, num_layers, scope, reuse
+            )
+        return hidden_flat
+
     @staticmethod
     def create_nature_cnn_visual_observation_encoder(
         image_input: tf.Tensor,
@@ -475,6 +526,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction:
             EncoderType.SIMPLE: ModelUtils.create_visual_observation_encoder,
             EncoderType.NATURE_CNN: ModelUtils.create_nature_cnn_visual_observation_encoder,
             EncoderType.RESNET: ModelUtils.create_resnet_visual_observation_encoder,
+            EncoderType.MATCH3: ModelUtils.create_match3_visual_observation_encoder,
         }
         return ENCODER_FUNCTION_BY_TYPE.get(
             encoder_type, ModelUtils.create_visual_observation_encoder
37 changes: 37 additions & 0 deletions ml-agents/mlagents/trainers/torch/encoders.py
@@ -109,6 +109,43 @@ def update_normalization(self, inputs: torch.Tensor) -> None:
             self.normalizer.update(inputs)


+class SmallVisualEncoder(nn.Module):
+    """
+    CNN architecture used by King in their Candy Crush predictor
+    https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning
+    """
+
+    def __init__(
+        self, height: int, width: int, initial_channels: int, output_size: int
+    ):
+        super().__init__()
+        self.h_size = output_size
+        conv_1_hw = conv_output_shape((height, width), 3, 1)
+        conv_2_hw = conv_output_shape(conv_1_hw, 3, 1)
+        self.final_flat = conv_2_hw[0] * conv_2_hw[1] * 144
+
+        self.conv_layers = nn.Sequential(
+            nn.Conv2d(initial_channels, 35, [3, 3], [1, 1]),
+            nn.LeakyReLU(),
+            nn.Conv2d(35, 144, [3, 3], [1, 1]),
+            nn.LeakyReLU(),
+        )
+        self.dense = nn.Sequential(
+            linear_layer(
+                self.final_flat,
+                self.h_size,
+                kernel_init=Initialization.KaimingHeNormal,
+                kernel_gain=1.0,
+            ),
+            nn.LeakyReLU(),
+        )
+
+    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
+        hidden = self.conv_layers(visual_obs)
+        hidden = torch.reshape(hidden, (-1, self.final_flat))
+        return self.dense(hidden)
+
+
 class SimpleVisualEncoder(nn.Module):
     def __init__(
         self, height: int, width: int, initial_channels: int, output_size: int
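
To get a feel for the size of `SmallVisualEncoder`, the sketch below repeats its shape arithmetic in plain Python and counts trainable parameters. It assumes unpadded 3x3, stride-1 convolutions with biases (the `nn.Conv2d` defaults) and a dense layer with a bias term; `small_encoder_sizes` is a hypothetical helper, not part of ML-Agents.

```python
def small_encoder_sizes(height: int, width: int, channels: int, h_size: int):
    """Return (flattened feature size, parameter count) for the layer stack."""
    # Each unpadded 3x3, stride-1 convolution trims 2 from height and width.
    out_h, out_w = height - 4, width - 4
    if out_h < 1 or out_w < 1:
        raise ValueError("input must be at least 5x5")
    final_flat = out_h * out_w * 144

    conv1 = 3 * 3 * channels * 35 + 35  # weights + biases, 35 filters
    conv2 = 3 * 3 * 35 * 144 + 144      # weights + biases, 144 filters
    dense = final_flat * h_size + h_size  # assumed biased linear layer
    return final_flat, conv1 + conv2 + dense


print(small_encoder_sizes(5, 5, 5, 128))
print(small_encoder_sizes(8, 8, 5, 128))
```

Under these assumptions the convolutional part has a fixed cost, while the dense layer grows with the board area, so it dominates the parameter count on larger grids.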
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/torch/utils.py
@@ -7,6 +7,7 @@
     SimpleVisualEncoder,
     ResNetVisualEncoder,
     NatureVisualEncoder,
+    SmallVisualEncoder,
     VectorInput,
 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
@@ -19,6 +20,7 @@ class ModelUtils:
     # Minimum supported side for each encoder type. If refactoring an encoder, please
     # adjust these also.
     MIN_RESOLUTION_FOR_ENCODER = {
+        EncoderType.MATCH3: 5,
         EncoderType.SIMPLE: 20,
         EncoderType.NATURE_CNN: 36,
         EncoderType.RESNET: 15,
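
A table like `MIN_RESOLUTION_FOR_ENCODER` lends itself to an up-front check before a network is built. The following is a minimal sketch of such a check, using plain strings as keys and a hypothetical `check_resolution` helper; how ML-Agents itself performs this validation is not shown in this diff.

```python
# Minimum side length per encoder, copied from the table in this diff.
MIN_RESOLUTION_FOR_ENCODER = {
    "match3": 5,
    "simple": 20,
    "nature_cnn": 36,
    "resnet": 15,
}


def check_resolution(encoder: str, height: int, width: int) -> None:
    """Raise ValueError if either side is below the encoder's minimum."""
    min_side = MIN_RESOLUTION_FOR_ENCODER[encoder]
    if height < min_side or width < min_side:
        raise ValueError(
            f"{encoder} needs at least {min_side}x{min_side}, "
            f"got {height}x{width}"
        )


check_resolution("match3", 5, 5)  # a 5x5 board is accepted for match3

try:
    check_resolution("simple", 5, 5)  # too small for the simple encoder
except ValueError as err:
    print(err)
```

This mirrors why the tests above switch to a `(5, 5, 5)` observation only for `match3` and keep `(36, 36, 3)` for the other encoders.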
@@ -124,6 +126,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
             EncoderType.SIMPLE: SimpleVisualEncoder,
             EncoderType.NATURE_CNN: NatureVisualEncoder,
             EncoderType.RESNET: ResNetVisualEncoder,
+            EncoderType.MATCH3: SmallVisualEncoder,
         }
         return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)
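
Both `get_encoder_for_type` implementations in this PR dispatch through a dictionary, but they treat a missing key differently: the TensorFlow version passes a default to `.get`, while the PyTorch version does not, so an unregistered type would come back as `None`. A small sketch of that difference, with stand-in strings in place of the real encoder classes:

```python
# Stand-in registry; the values are placeholders, not the real classes.
ENCODERS = {
    "simple": "SimpleVisualEncoder",
    "match3": "SmallVisualEncoder",
}

# With a default, an unregistered key falls back to the simple encoder.
with_default = ENCODERS.get("unregistered", ENCODERS["simple"])

# Without a default, the same lookup yields None.
without_default = ENCODERS.get("unregistered")

print(with_default, without_default)
```

This is why a new `EncoderType` member needs a matching entry in each dictionary, as the `MATCH3` lines in this diff provide.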
