diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs index 59d01e4790..365d99cc13 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using UnityEngine; @@ -64,22 +65,43 @@ public abstract class AbstractBoard : MonoBehaviour /// public abstract bool MakeMove(Move m); + /// + /// Return the total number of moves possible for the board. + /// + /// + public int NumMoves() + { + return Move.NumPotentialMoves(Rows, Columns); + } + + /// + /// An optional callback for when the all moves are invalid. Ideally, the game state should + /// be changed before this happens, but this is a way to get notified if not. + /// + public Action OnNoValidMovesAction; + + /// + /// Iterate through all Moves on the board. + /// + /// public IEnumerable AllMoves() { var currentMove = Move.FromMoveIndex(0, Rows, Columns); - var numMoves = Move.NumPotentialMoves(Rows, Columns); - for (var i = 0; i < numMoves; i++) + for (var i = 0; i < NumMoves(); i++) { yield return currentMove; currentMove.Advance(Rows, Columns); } } + /// + /// Iterate through all valid Moves on the board. + /// + /// public IEnumerable ValidMoves() { var currentMove = Move.FromMoveIndex(0, Rows, Columns); - var numMoves = Move.NumPotentialMoves(Rows, Columns); - for (var i = 0; i < numMoves; i++) + for (var i = 0; i < NumMoves(); i++) { if (IsMoveValid(currentMove)) { @@ -89,11 +111,14 @@ public IEnumerable ValidMoves() } } + /// + /// Iterate through all invalid Moves on the board. + /// + /// public IEnumerable InvalidMoves() { var currentMove = Move.FromMoveIndex(0, Rows, Columns); - var numMoves = Move.NumPotentialMoves(Rows, Columns); - for (var i = 0; i < numMoves; i++) + for (var i = 0; i < NumMoves(); i++) { if (!IsMoveValid(currentMove)) { diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs index 4c4485f643..8ad9c10b65 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs @@ -75,8 +75,31 @@ public void ResetData() IEnumerable InvalidMoveIndices() { + var numValidMoves = m_Board.NumMoves(); + foreach (var move in m_Board.InvalidMoves()) { + numValidMoves--; + if (numValidMoves == 0) + { + // If all the moves are invalid and we mask all the actions out, this will cause an assert + // later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one, + // (or log a warning if not) and leave the last action unmasked. This isn't great, but + // an invalid move should be easier to handle than an exception.. + if (m_Board.OnNoValidMovesAction != null) + { + m_Board.OnNoValidMovesAction(); + } + else + { + Debug.LogWarning( + "No valid moves are available. The last action will be left unmasked, so " + + "an invalid move will be passed to AbstractBoard.MakeMove()." + ); + } + // This means the last move won't be returned as an invalid index. + yield break; + } yield return move.MoveIndex; } } diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs index f60ec55d8d..35bf4c405a 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using UnityEngine; using NUnit.Framework; -using Unity.MLAgents.Sensors; using Unity.MLAgents.Extensions.Match3; namespace Unity.MLAgents.Extensions.Tests.Match3 diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs new file mode 100644 index 0000000000..32d1acc214 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs @@ -0,0 +1,88 @@ +using NUnit.Framework; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Extensions.Match3; +using UnityEngine; + +namespace Unity.MLAgents.Extensions.Tests.Match3 +{ + internal class SimpleBoard : AbstractBoard + { + public int LastMoveIndex; + public bool MovesAreValid = true; + + public bool CallbackCalled; + + public override int GetCellType(int row, int col) + { + return 0; + } + + public override int GetSpecialType(int row, int col) + { + return 0; + } + + public override bool IsMoveValid(Move m) + { + return MovesAreValid; + } + + public override bool MakeMove(Move m) + { + LastMoveIndex = m.MoveIndex; + return MovesAreValid; + } + + public void Callback() + { + CallbackCalled = true; + } + } + + public class Match3ActuatorTests + { + [SetUp] + public void SetUp() + { + if (Academy.IsInitialized) + { + Academy.Instance.Dispose(); + } + } + + [TestCase(true)] + [TestCase(false)] + public void TestValidMoves(bool movesAreValid) + { + // Check that a board with no valid moves doesn't raise an exception. + var gameObj = new GameObject(); + var board = gameObj.AddComponent(); + var agent = gameObj.AddComponent(); + gameObj.AddComponent(); + + board.Rows = 5; + board.Columns = 5; + board.NumCellTypes = 5; + board.NumSpecialTypes = 0; + + board.MovesAreValid = movesAreValid; + board.OnNoValidMovesAction = board.Callback; + board.LastMoveIndex = -1; + + agent.LazyInitialize(); + agent.RequestDecision(); + Academy.Instance.EnvironmentStep(); + + if (movesAreValid) + { + Assert.IsFalse(board.CallbackCalled); + } + else + { + Assert.IsTrue(board.CallbackCalled); + } + Assert.AreNotEqual(-1, board.LastMoveIndex); + } + + } +} diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta new file mode 100644 index 0000000000..3731b4e758 --- /dev/null +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 2edf24df24ac426085cb31a94d063683 +timeCreated: 1603392289 \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs index abd10bc858..dc94355794 100644 --- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs +++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using UnityEngine; using NUnit.Framework; -using Unity.MLAgents.Sensors; using Unity.MLAgents.Extensions.Match3; namespace Unity.MLAgents.Extensions.Tests.Match3