diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
index 59d01e4790..365d99cc13 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
@@ -1,3 +1,4 @@
+using System;
using System.Collections.Generic;
using UnityEngine;
@@ -64,22 +65,43 @@ public abstract class AbstractBoard : MonoBehaviour
///
public abstract bool MakeMove(Move m);
+ ///
+ /// Return the total number of moves possible for the board.
+ ///
+ ///
+ public int NumMoves()
+ {
+ return Move.NumPotentialMoves(Rows, Columns);
+ }
+
+ ///
+ /// An optional callback for when the all moves are invalid. Ideally, the game state should
+ /// be changed before this happens, but this is a way to get notified if not.
+ ///
+ public Action OnNoValidMovesAction;
+
+ ///
+ /// Iterate through all Moves on the board.
+ ///
+ ///
public IEnumerable AllMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
- var numMoves = Move.NumPotentialMoves(Rows, Columns);
- for (var i = 0; i < numMoves; i++)
+ for (var i = 0; i < NumMoves(); i++)
{
yield return currentMove;
currentMove.Advance(Rows, Columns);
}
}
+ ///
+ /// Iterate through all valid Moves on the board.
+ ///
+ ///
public IEnumerable ValidMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
- var numMoves = Move.NumPotentialMoves(Rows, Columns);
- for (var i = 0; i < numMoves; i++)
+ for (var i = 0; i < NumMoves(); i++)
{
if (IsMoveValid(currentMove))
{
@@ -89,11 +111,14 @@ public IEnumerable ValidMoves()
}
}
+ ///
+ /// Iterate through all invalid Moves on the board.
+ ///
+ ///
public IEnumerable InvalidMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
- var numMoves = Move.NumPotentialMoves(Rows, Columns);
- for (var i = 0; i < numMoves; i++)
+ for (var i = 0; i < NumMoves(); i++)
{
if (!IsMoveValid(currentMove))
{
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
index 4c4485f643..8ad9c10b65 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
@@ -75,8 +75,31 @@ public void ResetData()
IEnumerable InvalidMoveIndices()
{
+ var numValidMoves = m_Board.NumMoves();
+
foreach (var move in m_Board.InvalidMoves())
{
+ numValidMoves--;
+ if (numValidMoves == 0)
+ {
+ // If all the moves are invalid and we mask all the actions out, this will cause an assert
+ // later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,
+ // (or log a warning if not) and leave the last action unmasked. This isn't great, but
+ // an invalid move should be easier to handle than an exception..
+ if (m_Board.OnNoValidMovesAction != null)
+ {
+ m_Board.OnNoValidMovesAction();
+ }
+ else
+ {
+ Debug.LogWarning(
+ "No valid moves are available. The last action will be left unmasked, so " +
+ "an invalid move will be passed to AbstractBoard.MakeMove()."
+ );
+ }
+ // This means the last move won't be returned as an invalid index.
+ yield break;
+ }
yield return move.MoveIndex;
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs
index f60ec55d8d..35bf4c405a 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs
@@ -2,7 +2,6 @@
using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
-using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgents.Extensions.Tests.Match3
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs
new file mode 100644
index 0000000000..32d1acc214
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs
@@ -0,0 +1,88 @@
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Extensions.Match3;
+using UnityEngine;
+
+namespace Unity.MLAgents.Extensions.Tests.Match3
+{
+ internal class SimpleBoard : AbstractBoard
+ {
+ public int LastMoveIndex;
+ public bool MovesAreValid = true;
+
+ public bool CallbackCalled;
+
+ public override int GetCellType(int row, int col)
+ {
+ return 0;
+ }
+
+ public override int GetSpecialType(int row, int col)
+ {
+ return 0;
+ }
+
+ public override bool IsMoveValid(Move m)
+ {
+ return MovesAreValid;
+ }
+
+ public override bool MakeMove(Move m)
+ {
+ LastMoveIndex = m.MoveIndex;
+ return MovesAreValid;
+ }
+
+ public void Callback()
+ {
+ CallbackCalled = true;
+ }
+ }
+
+ public class Match3ActuatorTests
+ {
+ [SetUp]
+ public void SetUp()
+ {
+ if (Academy.IsInitialized)
+ {
+ Academy.Instance.Dispose();
+ }
+ }
+
+ [TestCase(true)]
+ [TestCase(false)]
+ public void TestValidMoves(bool movesAreValid)
+ {
+ // Check that a board with no valid moves doesn't raise an exception.
+ var gameObj = new GameObject();
+ var board = gameObj.AddComponent();
+ var agent = gameObj.AddComponent();
+ gameObj.AddComponent();
+
+ board.Rows = 5;
+ board.Columns = 5;
+ board.NumCellTypes = 5;
+ board.NumSpecialTypes = 0;
+
+ board.MovesAreValid = movesAreValid;
+ board.OnNoValidMovesAction = board.Callback;
+ board.LastMoveIndex = -1;
+
+ agent.LazyInitialize();
+ agent.RequestDecision();
+ Academy.Instance.EnvironmentStep();
+
+ if (movesAreValid)
+ {
+ Assert.IsFalse(board.CallbackCalled);
+ }
+ else
+ {
+ Assert.IsTrue(board.CallbackCalled);
+ }
+ Assert.AreNotEqual(-1, board.LastMoveIndex);
+ }
+
+ }
+}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta
new file mode 100644
index 0000000000..3731b4e758
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 2edf24df24ac426085cb31a94d063683
+timeCreated: 1603392289
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs
index abd10bc858..dc94355794 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs
@@ -1,8 +1,4 @@
-using System;
-using System.Collections.Generic;
-using UnityEngine;
using NUnit.Framework;
-using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgents.Extensions.Tests.Match3