diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab
index d52e1f4bd9..89e7df5f43 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab
+++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab
@@ -109,6 +109,7 @@ MonoBehaviour:
Rows: 9
Columns: 8
NumCellTypes: 6
+ NumSpecialTypes: 2
RandomSeed: -1
--- !u!114 &3508723250470608013
MonoBehaviour:
@@ -135,7 +136,7 @@ MonoBehaviour:
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
- UseVectorObservations: 1
+ ObservationType: 0
--- !u!1 &3508723250774301855
GameObject:
m_ObjectHideFlags: 0
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab
index 6151c76e10..4c0d545f35 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab
+++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab
@@ -140,6 +140,7 @@ MonoBehaviour:
Rows: 9
Columns: 8
NumCellTypes: 6
+ NumSpecialTypes: 2
RandomSeed: -1
--- !u!114 &2118285884327540683
MonoBehaviour:
@@ -166,4 +167,4 @@ MonoBehaviour:
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
- UseVectorObservations: 1
+ ObservationType: 0
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab
index 4074069b95..6136339890 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab
+++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab
@@ -140,6 +140,7 @@ MonoBehaviour:
Rows: 9
Columns: 8
NumCellTypes: 6
+ NumSpecialTypes: 2
RandomSeed: -1
--- !u!114 &3019509692332007780
MonoBehaviour:
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
index ad62ef6699..38a8b998de 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
@@ -179,6 +179,11 @@ PrefabInstance:
propertyPath: m_Name
value: Match3VisualObs (3)
objectReference: {fileID: 0}
+ - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3}
--- !u!1 &327661542
@@ -611,6 +616,16 @@ PrefabInstance:
propertyPath: m_Name
value: Match3VisualObs (1)
objectReference: {fileID: 0}
+ - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
+ - target: {fileID: 3019509692332007790, guid: aaa471bd5e2014848a66917476671aed,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3}
--- !u!1001 &1278119417
@@ -685,6 +700,11 @@ PrefabInstance:
propertyPath: m_IsActive
value: 1
objectReference: {fileID: 0}
+ - target: {fileID: 2118285884327540673, guid: 6944ca02359f5427aa13c8551236a824,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 6944ca02359f5427aa13c8551236a824, type: 3}
--- !u!1001 &1479255359
@@ -966,6 +986,11 @@ PrefabInstance:
propertyPath: m_Name
value: Match3VisualObs (2)
objectReference: {fileID: 0}
+ - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3}
--- !u!1001 &2118285882709515366
@@ -1035,6 +1060,16 @@ PrefabInstance:
propertyPath: m_Name
value: Match3VectorObs
objectReference: {fileID: 0}
+ - target: {fileID: 2118285883905619929, guid: 6944ca02359f5427aa13c8551236a824,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
+ - target: {fileID: 2118285884327540673, guid: 6944ca02359f5427aa13c8551236a824,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
- target: {fileID: 2118285884327540680, guid: 6944ca02359f5427aa13c8551236a824,
type: 3}
propertyPath: UseVectorObservations
@@ -1109,5 +1144,10 @@ PrefabInstance:
propertyPath: m_Name
value: Match3VisualObs
objectReference: {fileID: 0}
+ - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed,
+ type: 3}
+ propertyPath: m_IsActive
+ value: 1
+ objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs
index 2d6a1a1c88..b3cd10250b 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs
@@ -10,14 +10,14 @@ public class Match3Board : AbstractBoard
public const int k_EmptyCell = -1;
- int[,] m_Cells;
+ (int, int)[,] m_Cells;
bool[,] m_Matched;
System.Random m_Random;
void Awake()
{
- m_Cells = new int[Columns, Rows];
+ m_Cells = new (int, int)[Columns, Rows];
m_Matched = new bool[Columns, Rows];
m_Random = new System.Random(RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed);
@@ -42,7 +42,12 @@ public override bool MakeMove(Move move)
public override int GetCellType(int row, int col)
{
- return m_Cells[col, row];
+ return m_Cells[col, row].Item1;
+ }
+
+ public override int GetSpecialType(int row, int col)
+ {
+ return m_Cells[col, row].Item2;
}
public override bool IsMoveValid(Move m)
@@ -67,7 +72,7 @@ public bool MarkMatchedCells(int[,] cells = null)
var matchedRows = 0;
for (var iOffset = i; iOffset < Rows; iOffset++)
{
- if (m_Cells[j, i] != m_Cells[j, iOffset])
+ if (m_Cells[j, i].Item1 != m_Cells[j, iOffset].Item1)
{
break;
}
@@ -89,7 +94,7 @@ public bool MarkMatchedCells(int[,] cells = null)
var matchedCols = 0;
for (var jOffset = j; jOffset < Columns; jOffset++)
{
- if (m_Cells[j, i] != m_Cells[jOffset, i])
+ if (m_Cells[j, i].Item1 != m_Cells[jOffset, i].Item1)
{
break;
}
@@ -122,7 +127,7 @@ public int ClearMatchedCells()
if (m_Matched[j, i])
{
numMatchedCells++;
- m_Cells[j, i] = k_EmptyCell;
+ m_Cells[j, i] = (k_EmptyCell, 0);
}
}
}
@@ -141,7 +146,7 @@ public bool DropCells()
for (var readIndex = 0; readIndex < Rows; readIndex++)
{
m_Cells[j, writeIndex] = m_Cells[j, readIndex];
- if (m_Cells[j, readIndex] != k_EmptyCell)
+ if (m_Cells[j, readIndex].Item1 != k_EmptyCell)
{
writeIndex++;
}
@@ -152,7 +157,7 @@ public bool DropCells()
for (; writeIndex < Rows; writeIndex++)
{
madeChanges = true;
- m_Cells[j, writeIndex] = k_EmptyCell;
+ m_Cells[j, writeIndex] = (k_EmptyCell, 0);
}
}
@@ -166,10 +171,10 @@ public bool FillFromAbove()
{
for (var j = 0; j < Columns; j++)
{
- if (m_Cells[j, i] == k_EmptyCell)
+ if (m_Cells[j, i].Item1 == k_EmptyCell)
{
madeChanges = true;
- m_Cells[j, i] = m_Random.Next(0, NumCellTypes);
+ m_Cells[j, i] = (GetRandomCellType(), GetRandomSpecialType());
}
}
}
@@ -177,7 +182,7 @@ public bool FillFromAbove()
return madeChanges;
}
- public int[,] Cells
+ public (int, int)[,] Cells
{
get { return m_Cells; }
}
@@ -194,7 +199,7 @@ public void InitRandom()
{
for (var j = 0; j < Columns; j++)
{
- m_Cells[j, i] = m_Random.Next(0, NumCellTypes);
+ m_Cells[j, i] = (GetRandomCellType(), GetRandomSpecialType());
}
}
}
@@ -226,6 +231,30 @@ void ClearMarked()
}
}
+ int GetRandomCellType()
+ {
+ return m_Random.Next(0, NumCellTypes);
+ }
+
+ int GetRandomSpecialType()
+ {
+ // 1/N chance to get a type-2 special
+ // 2/N chance to get a type-1 special
+ // otherwise 0 (boring)
+ var N = 10;
+ var val = m_Random.Next(0, N);
+ if (val == 0)
+ {
+ return 2;
+ }
+
+ if (val <= 2)
+ {
+ return 1;
+ }
+
+ return 0;
+ }
}
}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs
index fc76cd858b..eb82ad9ad7 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs
@@ -38,7 +38,7 @@ void OnDrawGizmos()
{
for (var j = 0; j < board.Columns; j++)
{
- var value = board.Cells != null ? board.Cells[j, i] : Match3Board.k_EmptyCell;
+ var value = board.Cells != null ? board.GetCellType(i, j) : Match3Board.k_EmptyCell;
if (value >= 0 && value < s_Colors.Length)
{
Gizmos.color = s_Colors[value];
@@ -51,7 +51,25 @@ void OnDrawGizmos()
var pos = new Vector3(j, i, 0);
pos *= cubeSpacing;
- Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * Vector3.one);
+ var specialType = board.Cells != null ? board.GetSpecialType(i, j) : 0;
+ if (specialType == 2)
+ {
+ Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(1f, .5f, .5f));
+ Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(.5f, 1f, .5f));
+ Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(.5f, .5f, 1f));
+ }
+ else if (specialType == 1)
+ {
+ Gizmos.DrawSphere(transform.TransformPoint(pos), .5f * cubeSize);
+ }
+ else
+ {
+ Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * Vector3.one);
+ }
+
+
+
+
Gizmos.color = Color.yellow;
if (board.Matched != null && board.Matched[j, i])
diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx
index 63b105fb3c..4580450f04 100644
Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx differ
diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn
index 3c82e464ac..d0193e8da6 100644
Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn differ
diff --git a/com.unity.ml-agents.extensions/Documentation~/Match3.md b/com.unity.ml-agents.extensions/Documentation~/Match3.md
index 426de52f36..d8eca58e97 100644
--- a/com.unity.ml-agents.extensions/Documentation~/Match3.md
+++ b/com.unity.ml-agents.extensions/Documentation~/Match3.md
@@ -5,6 +5,7 @@ We provide some utilities to integrate ML-Agents with Match-3 games.
## AbstractBoard class
The `AbstractBoard` is the bridge between ML-Agents and your game. It allows ML-Agents to
* ask your game what the "color" of a cell is
+* ask whether the cell is a "special" piece type or not
* ask your game whether a move is allowed
* request that your game make a move
@@ -17,6 +18,11 @@ Returns the "color" of piece at the given row and column.
This should be between 0 and NumCellTypes-1 (inclusive).
The actual order of the values doesn't matter.
+#### `public abstract int GetSpecialType(int row, int col)`
+Returns the special type of the piece at the given row and column.
+This should be between 0 and NumSpecialTypes (inclusive).
+The actual order of the values doesn't matter.
+
#### `public abstract bool IsMoveValid(Move m)`
Check whether the particular `Move` is valid for the game.
The actual results will depend on the rules of the game, but we provide the `SimpleIsMoveValid()` method
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
index 546b795408..59d01e4790 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
@@ -5,13 +5,29 @@ namespace Unity.MLAgents.Extensions.Match3
{
public abstract class AbstractBoard : MonoBehaviour
{
+ ///
+ /// Number of rows on the board
+ ///
public int Rows;
+
+ ///
+ /// Number of columns on the board
+ ///
public int Columns;
+
+ ///
+ /// Maximum number of different types of cells (colors, pieces, etc).
+ ///
public int NumCellTypes;
+ ///
+ /// Maximum number of special types. This can be zero, in which case
+ /// all cells of the same type are assumed to be equivalent.
+ ///
+ public int NumSpecialTypes;
///
- /// Returns the "color" of piece at the given row and column.
+ /// Returns the "color" of the piece at the given row and column.
/// This should be between 0 and NumCellTypes-1 (inclusive).
/// The actual order of the values doesn't matter.
///
@@ -20,6 +36,16 @@ public abstract class AbstractBoard : MonoBehaviour
///
public abstract int GetCellType(int row, int col);
+ ///
+ /// Returns the special type of the piece at the given row and column.
+ /// This should be between 0 and NumSpecialTypes (inclusive).
+ /// The actual order of the values doesn't matter.
+ ///
+ ///
+ ///
+ ///
+ public abstract int GetSpecialType(int row, int col);
+
///
/// Check whether the particular Move is valid for the game.
/// The actual results will depend on the rules of the game, but we provide SimpleIsMoveValid()
@@ -38,8 +64,6 @@ public abstract class AbstractBoard : MonoBehaviour
///
public abstract bool MakeMove(Move m);
- // TODO handle "special" cell types?
-
public IEnumerable AllMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
index 00c74e027a..697fa04ce4 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs
@@ -11,15 +11,23 @@ public enum Match3ObservationType
CompressedVisual
}
- public class Match3Sensor : ISensor
+ public class Match3Sensor : ISparseChannelSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
+ private int[] m_SparseChannelMapping;
private int m_Rows;
private int m_Columns;
private int m_NumCellTypes;
+ private int m_NumSpecialTypes;
+ private ISparseChannelSensor sparseChannelSensorImplementation;
+
+ private int SpecialTypeSize
+ {
+ get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; }
+ }
public Match3Sensor(AbstractBoard board, Match3ObservationType obsType)
{
@@ -27,11 +35,33 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType)
m_Rows = board.Rows;
m_Columns = board.Columns;
m_NumCellTypes = board.NumCellTypes;
+ m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
- new[] { m_Rows * m_Columns * m_NumCellTypes } :
- new[] { m_Rows, m_Columns, m_NumCellTypes };
+ new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
+ new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
+
+ // See comment in GetCompressedObservation()
+ var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
+ m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize];
+ // If we have 4 cell types and 2 special types (3 special size), we'd have
+ // [0, 1, 2, 3, -1, -1, 4, 5, 6]
+ for (var i = 0; i < m_NumCellTypes; i++)
+ {
+ m_SparseChannelMapping[i] = i;
+ }
+
+ for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++)
+ {
+ m_SparseChannelMapping[i] = -1;
+ }
+
+ for (var i = 0; i < SpecialTypeSize; i++)
+ {
+ m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes;
+ }
+
}
public int[] GetObservationShape()
@@ -63,6 +93,16 @@ public int Write(ObservationWriter writer)
writer[offset] = (i == val) ? 1.0f : 0.0f;
offset++;
}
+
+ if (m_NumSpecialTypes > 0)
+ {
+ var special = m_Board.GetSpecialType(r, c);
+ for (var i = 0; i < SpecialTypeSize; i++)
+ {
+ writer[offset] = (i == special) ? 1.0f : 0.0f;
+ offset++;
+ }
+ }
}
}
@@ -82,6 +122,16 @@ public int Write(ObservationWriter writer)
writer[r, c, i] = (i == val) ? 1.0f : 0.0f;
offset++;
}
+
+ if (m_NumSpecialTypes > 0)
+ {
+ var special = m_Board.GetSpecialType(r, c);
+ for (var i = 0; i < SpecialTypeSize; i++)
+ {
+ writer[offset] = (i == special) ? 1.0f : 0.0f;
+ offset++;
+ }
+ }
}
}
@@ -96,10 +146,23 @@ public byte[] GetCompressedObservation()
var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
var converter = new OneHotToTextureUtil(height, width);
var bytesOut = new List();
- var numImages = (m_NumCellTypes + 2) / 3;
- for (var i = 0; i < numImages; i++)
+
+ // Encode the cell types and special types as separate batches of PNGs
+ // This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could
+ // fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for
+ // the special types). Note that we have to also implement the sparse channel mapping.
+ // Optimize this it later.
+ var numCellImages = (m_NumCellTypes + 2) / 3;
+ for (var i = 0; i < numCellImages; i++)
+ {
+ converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i);
+ bytesOut.AddRange(tempTexture.EncodeToPNG());
+ }
+
+ var numSpecialImages = (SpecialTypeSize + 2) / 3;
+ for (var i = 0; i < numSpecialImages; i++)
{
- converter.EncodeToTexture(m_Board, tempTexture, 3 * i);
+ converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i);
bytesOut.AddRange(tempTexture.EncodeToPNG());
}
@@ -127,6 +190,11 @@ public string GetName()
return "Match3 Sensor";
}
+ public int[] GetCompressedChannelMapping()
+ {
+ return m_SparseChannelMapping;
+ }
+
static void DestroyTexture(Texture2D texture)
{
if (Application.isEditor)
@@ -155,6 +223,9 @@ public class OneHotToTextureUtil
int m_Width;
private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };
+ public delegate int GridValueProvider(int x, int y);
+
+
public OneHotToTextureUtil(int height, int width)
{
m_Colors = new Color[height * width];
@@ -162,7 +233,7 @@ public OneHotToTextureUtil(int height, int width)
m_Width = width;
}
- public void EncodeToTexture(AbstractBoard board, Texture2D texture, int channelOffset)
+ public void EncodeToTexture(GridValueProvider gridValueProvider, Texture2D texture, int channelOffset)
{
var i = 0;
// There's an implicit flip converting to PNG from texture, so make sure we
@@ -171,7 +242,7 @@ public void EncodeToTexture(AbstractBoard board, Texture2D texture, int channelO
{
for (var w = 0; w < m_Width; w++)
{
- int oneHotValue = board.GetCellType(h, w);
+ int oneHotValue = gridValueProvider(h, w);
if (oneHotValue < channelOffset || oneHotValue >= channelOffset + 3)
{
m_Colors[i++] = Color.black;
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
index 2e678e7657..db336fc59a 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs
@@ -20,9 +20,10 @@ public override int[] GetObservationShape()
return System.Array.Empty();
}
+ var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
return ObservationType == Match3ObservationType.Vector ?
- new[] { board.Rows * board.Columns * board.NumCellTypes } :
- new[] { board.Rows, board.Columns, board.NumCellTypes };
+ new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } :
+ new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize };
}
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs
index 8af6012ef4..f60ec55d8d 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs
@@ -46,6 +46,12 @@ public override int GetCellType(int row, int col)
var character = m_Board[m_Board.Length - 1 - row][col];
return (int)(character - '0');
}
+
+ public override int GetSpecialType(int row, int col)
+ {
+ return 0;
+ }
+
}
public class AbstractBoardTests