diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab index d52e1f4bd9..89e7df5f43 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab +++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab @@ -109,6 +109,7 @@ MonoBehaviour: Rows: 9 Columns: 8 NumCellTypes: 6 + NumSpecialTypes: 2 RandomSeed: -1 --- !u!114 &3508723250470608013 MonoBehaviour: @@ -135,7 +136,7 @@ MonoBehaviour: m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} m_Name: m_EditorClassIdentifier: - UseVectorObservations: 1 + ObservationType: 0 --- !u!1 &3508723250774301855 GameObject: m_ObjectHideFlags: 0 diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab index 6151c76e10..4c0d545f35 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab +++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab @@ -140,6 +140,7 @@ MonoBehaviour: Rows: 9 Columns: 8 NumCellTypes: 6 + NumSpecialTypes: 2 RandomSeed: -1 --- !u!114 &2118285884327540683 MonoBehaviour: @@ -166,4 +167,4 @@ MonoBehaviour: m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} m_Name: m_EditorClassIdentifier: - UseVectorObservations: 1 + ObservationType: 0 diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab index 4074069b95..6136339890 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab +++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab @@ -140,6 +140,7 @@ MonoBehaviour: Rows: 9 Columns: 8 NumCellTypes: 6 + NumSpecialTypes: 2 RandomSeed: -1 --- !u!114 &3019509692332007780 MonoBehaviour: diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity index ad62ef6699..38a8b998de 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity +++ b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity @@ -179,6 +179,11 @@ PrefabInstance: propertyPath: m_Name value: Match3VisualObs (3) objectReference: {fileID: 0} + - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3} --- !u!1 &327661542 @@ -611,6 +616,16 @@ PrefabInstance: propertyPath: m_Name value: Match3VisualObs (1) objectReference: {fileID: 0} + - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 3019509692332007790, guid: aaa471bd5e2014848a66917476671aed, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3} --- !u!1001 &1278119417 @@ -685,6 +700,11 @@ PrefabInstance: propertyPath: m_IsActive value: 1 objectReference: {fileID: 0} + - target: {fileID: 2118285884327540673, guid: 6944ca02359f5427aa13c8551236a824, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: 6944ca02359f5427aa13c8551236a824, type: 3} --- !u!1001 &1479255359 @@ -966,6 +986,11 @@ PrefabInstance: propertyPath: m_Name value: Match3VisualObs (2) objectReference: {fileID: 0} + - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3} --- !u!1001 &2118285882709515366 @@ -1035,6 +1060,16 @@ PrefabInstance: propertyPath: m_Name value: Match3VectorObs objectReference: {fileID: 0} + - target: {fileID: 2118285883905619929, guid: 6944ca02359f5427aa13c8551236a824, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} + - target: {fileID: 2118285884327540673, guid: 6944ca02359f5427aa13c8551236a824, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} - target: {fileID: 2118285884327540680, guid: 6944ca02359f5427aa13c8551236a824, type: 3} propertyPath: UseVectorObservations @@ -1109,5 +1144,10 @@ PrefabInstance: propertyPath: m_Name value: Match3VisualObs objectReference: {fileID: 0} + - target: {fileID: 3019509691567202678, guid: aaa471bd5e2014848a66917476671aed, + type: 3} + propertyPath: m_IsActive + value: 1 + objectReference: {fileID: 0} m_RemovedComponents: [] m_SourcePrefab: {fileID: 100100000, guid: aaa471bd5e2014848a66917476671aed, type: 3} diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs index 2d6a1a1c88..b3cd10250b 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs @@ -10,14 +10,14 @@ public class Match3Board : AbstractBoard public const int k_EmptyCell = -1; - int[,] m_Cells; + (int, int)[,] m_Cells; bool[,] m_Matched; System.Random m_Random; void Awake() { - m_Cells = new int[Columns, Rows]; + m_Cells = new (int, int)[Columns, Rows]; m_Matched = new bool[Columns, Rows]; m_Random = new System.Random(RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed); @@ -42,7 +42,12 @@ public override bool MakeMove(Move move) public override int GetCellType(int row, int col) { - return m_Cells[col, row]; + return m_Cells[col, row].Item1; + } + + public override int GetSpecialType(int row, int col) + { + return m_Cells[col, row].Item2; } public override bool IsMoveValid(Move m) @@ -67,7 +72,7 @@ public bool MarkMatchedCells(int[,] cells = null) var matchedRows = 0; for (var iOffset = i; iOffset < Rows; iOffset++) { - if (m_Cells[j, i] != m_Cells[j, iOffset]) + if (m_Cells[j, i].Item1 != m_Cells[j, iOffset].Item1) { break; } @@ -89,7 +94,7 @@ public bool MarkMatchedCells(int[,] cells = null) var matchedCols = 0; for (var jOffset = j; jOffset < Columns; jOffset++) { - if (m_Cells[j, i] != m_Cells[jOffset, i]) + if (m_Cells[j, i].Item1 != m_Cells[jOffset, i].Item1) { break; } @@ -122,7 +127,7 @@ public int ClearMatchedCells() if (m_Matched[j, i]) { numMatchedCells++; - m_Cells[j, i] = k_EmptyCell; + m_Cells[j, i] = (k_EmptyCell, 0); } } } @@ -141,7 +146,7 @@ public bool DropCells() for (var readIndex = 0; readIndex < Rows; readIndex++) { m_Cells[j, writeIndex] = m_Cells[j, readIndex]; - if (m_Cells[j, readIndex] != k_EmptyCell) + if (m_Cells[j, readIndex].Item1 != k_EmptyCell) { writeIndex++; } @@ -152,7 +157,7 @@ public bool DropCells() for (; writeIndex < Rows; writeIndex++) { madeChanges = true; - m_Cells[j, writeIndex] = k_EmptyCell; + m_Cells[j, writeIndex] = (k_EmptyCell, 0); } } @@ -166,10 +171,10 @@ public bool FillFromAbove() { for (var j = 0; j < Columns; j++) { - if (m_Cells[j, i] == k_EmptyCell) + if (m_Cells[j, i].Item1 == k_EmptyCell) { madeChanges = true; - m_Cells[j, i] = m_Random.Next(0, NumCellTypes); + m_Cells[j, i] = (GetRandomCellType(), GetRandomSpecialType()); } } } @@ -177,7 +182,7 @@ public bool FillFromAbove() return madeChanges; } - public int[,] Cells + public (int, int)[,] Cells { get { return m_Cells; } } @@ -194,7 +199,7 @@ public void InitRandom() { for (var j = 0; j < Columns; j++) { - m_Cells[j, i] = m_Random.Next(0, NumCellTypes); + m_Cells[j, i] = (GetRandomCellType(), GetRandomSpecialType()); } } } @@ -226,6 +231,30 @@ void ClearMarked() } } + int GetRandomCellType() + { + return m_Random.Next(0, NumCellTypes); + } + + int GetRandomSpecialType() + { + // 1/N chance to get a type-2 special + // 2/N chance to get a type-1 special + // otherwise 0 (boring) + var N = 10; + var val = m_Random.Next(0, N); + if (val == 0) + { + return 2; + } + + if (val <= 2) + { + return 1; + } + + return 0; + } } } diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs index fc76cd858b..eb82ad9ad7 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs @@ -38,7 +38,7 @@ void OnDrawGizmos() { for (var j = 0; j < board.Columns; j++) { - var value = board.Cells != null ? board.Cells[j, i] : Match3Board.k_EmptyCell; + var value = board.Cells != null ? board.GetCellType(i, j) : Match3Board.k_EmptyCell; if (value >= 0 && value < s_Colors.Length) { Gizmos.color = s_Colors[value]; @@ -51,7 +51,25 @@ void OnDrawGizmos() var pos = new Vector3(j, i, 0); pos *= cubeSpacing; - Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * Vector3.one); + var specialType = board.Cells != null ? board.GetSpecialType(i, j) : 0; + if (specialType == 2) + { + Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(1f, .5f, .5f)); + Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(.5f, 1f, .5f)); + Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(.5f, .5f, 1f)); + } + else if (specialType == 1) + { + Gizmos.DrawSphere(transform.TransformPoint(pos), .5f * cubeSize); + } + else + { + Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * Vector3.one); + } + + + + Gizmos.color = Color.yellow; if (board.Matched != null && board.Matched[j, i]) diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx index 63b105fb3c..4580450f04 100644 Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx differ diff --git a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn index 3c82e464ac..d0193e8da6 100644 Binary files a/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn and b/Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn differ diff --git a/com.unity.ml-agents.extensions/Documentation~/Match3.md b/com.unity.ml-agents.extensions/Documentation~/Match3.md index 426de52f36..d8eca58e97 100644 --- a/com.unity.ml-agents.extensions/Documentation~/Match3.md +++ b/com.unity.ml-agents.extensions/Documentation~/Match3.md @@ -5,6 +5,7 @@ We provide some utilities to integrate ML-Agents with Match-3 games. ## AbstractBoard class The `AbstractBoard` is the bridge between ML-Agents and your game. It allows ML-Agents to * ask your game what the "color" of a cell is +* ask whether the cell is a "special" piece type or not * ask your game whether a move is allowed * request that your game make a move @@ -17,6 +18,11 @@ Returns the "color" of piece at the given row and column. This should be between 0 and NumCellTypes-1 (inclusive). The actual order of the values doesn't matter. +#### `public abstract int GetSpecialType(int row, int col)` +Returns the special type of the piece at the given row and column. +This should be between 0 and NumSpecialTypes (inclusive). +The actual order of the values doesn't matter. + #### `public abstract bool IsMoveValid(Move m)` Check whether the particular `Move` is valid for the game. The actual results will depend on the rules of the game, but we provide the `SimpleIsMoveValid()` method diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs index 546b795408..59d01e4790 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs @@ -5,13 +5,29 @@ namespace Unity.MLAgents.Extensions.Match3 { public abstract class AbstractBoard : MonoBehaviour { + /// + /// Number of rows on the board + /// public int Rows; + + /// + /// Number of columns on the board + /// public int Columns; + + /// + /// Maximum number of different types of cells (colors, pieces, etc). + /// public int NumCellTypes; + /// + /// Maximum number of special types. This can be zero, in which case + /// all cells of the same type are assumed to be equivalent. + /// + public int NumSpecialTypes; /// - /// Returns the "color" of piece at the given row and column. + /// Returns the "color" of the piece at the given row and column. /// This should be between 0 and NumCellTypes-1 (inclusive). /// The actual order of the values doesn't matter. /// @@ -20,6 +36,16 @@ public abstract class AbstractBoard : MonoBehaviour /// public abstract int GetCellType(int row, int col); + /// + /// Returns the special type of the piece at the given row and column. + /// This should be between 0 and NumSpecialTypes (inclusive). + /// The actual order of the values doesn't matter. + /// + /// + /// + /// + public abstract int GetSpecialType(int row, int col); + /// /// Check whether the particular Move is valid for the game. /// The actual results will depend on the rules of the game, but we provide SimpleIsMoveValid() @@ -38,8 +64,6 @@ public abstract class AbstractBoard : MonoBehaviour /// public abstract bool MakeMove(Move m); - // TODO handle "special" cell types? - public IEnumerable AllMoves() { var currentMove = Move.FromMoveIndex(0, Rows, Columns); diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs index 00c74e027a..697fa04ce4 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs @@ -11,15 +11,23 @@ public enum Match3ObservationType CompressedVisual } - public class Match3Sensor : ISensor + public class Match3Sensor : ISparseChannelSensor { private Match3ObservationType m_ObservationType; private AbstractBoard m_Board; private int[] m_Shape; + private int[] m_SparseChannelMapping; private int m_Rows; private int m_Columns; private int m_NumCellTypes; + private int m_NumSpecialTypes; + private ISparseChannelSensor sparseChannelSensorImplementation; + + private int SpecialTypeSize + { + get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; } + } public Match3Sensor(AbstractBoard board, Match3ObservationType obsType) { @@ -27,11 +35,33 @@ public Match3Sensor(AbstractBoard board, Match3ObservationType obsType) m_Rows = board.Rows; m_Columns = board.Columns; m_NumCellTypes = board.NumCellTypes; + m_NumSpecialTypes = board.NumSpecialTypes; m_ObservationType = obsType; m_Shape = obsType == Match3ObservationType.Vector ? - new[] { m_Rows * m_Columns * m_NumCellTypes } : - new[] { m_Rows, m_Columns, m_NumCellTypes }; + new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } : + new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize }; + + // See comment in GetCompressedObservation() + var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3); + m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize]; + // If we have 4 cell types and 2 special types (3 special size), we'd have + // [0, 1, 2, 3, -1, -1, 4, 5, 6] + for (var i = 0; i < m_NumCellTypes; i++) + { + m_SparseChannelMapping[i] = i; + } + + for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++) + { + m_SparseChannelMapping[i] = -1; + } + + for (var i = 0; i < SpecialTypeSize; i++) + { + m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes; + } + } public int[] GetObservationShape() @@ -63,6 +93,16 @@ public int Write(ObservationWriter writer) writer[offset] = (i == val) ? 1.0f : 0.0f; offset++; } + + if (m_NumSpecialTypes > 0) + { + var special = m_Board.GetSpecialType(r, c); + for (var i = 0; i < SpecialTypeSize; i++) + { + writer[offset] = (i == special) ? 1.0f : 0.0f; + offset++; + } + } } } @@ -82,6 +122,16 @@ public int Write(ObservationWriter writer) writer[r, c, i] = (i == val) ? 1.0f : 0.0f; offset++; } + + if (m_NumSpecialTypes > 0) + { + var special = m_Board.GetSpecialType(r, c); + for (var i = 0; i < SpecialTypeSize; i++) + { + writer[offset] = (i == special) ? 1.0f : 0.0f; + offset++; + } + } } } @@ -96,10 +146,23 @@ public byte[] GetCompressedObservation() var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false); var converter = new OneHotToTextureUtil(height, width); var bytesOut = new List(); - var numImages = (m_NumCellTypes + 2) / 3; - for (var i = 0; i < numImages; i++) + + // Encode the cell types and special types as separate batches of PNGs + // This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could + // fit in in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for + // the special types). Note that we have to also implement the sparse channel mapping. + // Optimize this it later. + var numCellImages = (m_NumCellTypes + 2) / 3; + for (var i = 0; i < numCellImages; i++) + { + converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i); + bytesOut.AddRange(tempTexture.EncodeToPNG()); + } + + var numSpecialImages = (SpecialTypeSize + 2) / 3; + for (var i = 0; i < numSpecialImages; i++) { - converter.EncodeToTexture(m_Board, tempTexture, 3 * i); + converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i); bytesOut.AddRange(tempTexture.EncodeToPNG()); } @@ -127,6 +190,11 @@ public string GetName() return "Match3 Sensor"; } + public int[] GetCompressedChannelMapping() + { + return m_SparseChannelMapping; + } + static void DestroyTexture(Texture2D texture) { if (Application.isEditor) @@ -155,6 +223,9 @@ public class OneHotToTextureUtil int m_Width; private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue }; + public delegate int GridValueProvider(int x, int y); + + public OneHotToTextureUtil(int height, int width) { m_Colors = new Color[height * width]; @@ -162,7 +233,7 @@ public OneHotToTextureUtil(int height, int width) m_Width = width; } - public void EncodeToTexture(AbstractBoard board, Texture2D texture, int channelOffset) + public void EncodeToTexture(GridValueProvider gridValueProvider, Texture2D texture, int channelOffset) { var i = 0; // There's an implicit flip converting to PNG from texture, so make sure we @@ -171,7 +242,7 @@ public void EncodeToTexture(AbstractBoard board, Texture2D texture, int channelO { for (var w = 0; w < m_Width; w++) { - int oneHotValue = board.GetCellType(h, w); + int oneHotValue = gridValueProvider(h, w); if (oneHotValue < channelOffset || oneHotValue >= channelOffset + 3) { m_Colors[i++] = Color.black; diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs index 2e678e7657..db336fc59a 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs @@ -20,9 +20,10 @@ public override int[] GetObservationShape() return System.Array.Empty(); } + var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1; return ObservationType == Match3ObservationType.Vector ? - new[] { board.Rows * board.Columns * board.NumCellTypes } : - new[] { board.Rows, board.Columns, board.NumCellTypes }; + new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } : + new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize }; } } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs index 8af6012ef4..f60ec55d8d 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs @@ -46,6 +46,12 @@ public override int GetCellType(int row, int col) var character = m_Board[m_Board.Length - 1 - row][col]; return (int)(character - '0'); } + + public override int GetSpecialType(int row, int col) + { + return 0; + } + } public class AbstractBoardTests