Skip to content

handle no valid moves more gracefully #4598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using UnityEngine;

Expand Down Expand Up @@ -64,22 +65,43 @@ public abstract class AbstractBoard : MonoBehaviour
/// <returns></returns>
public abstract bool MakeMove(Move m);

/// <summary>
/// Return the total number of moves possible for the board.
/// </summary>
/// <returns></returns>
public int NumMoves()
{
return Move.NumPotentialMoves(Rows, Columns);
}

/// <summary>
/// An optional callback for when the all moves are invalid. Ideally, the game state should
/// be changed before this happens, but this is a way to get notified if not.
/// </summary>
public Action OnNoValidMovesAction;

/// <summary>
/// Iterate through all Moves on the board.
/// </summary>
/// <returns></returns>
public IEnumerable<Move> AllMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
var numMoves = Move.NumPotentialMoves(Rows, Columns);
for (var i = 0; i < numMoves; i++)
for (var i = 0; i < NumMoves(); i++)
{
yield return currentMove;
currentMove.Advance(Rows, Columns);
}
}

/// <summary>
/// Iterate through all valid Moves on the board.
/// </summary>
/// <returns></returns>
public IEnumerable<Move> ValidMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
var numMoves = Move.NumPotentialMoves(Rows, Columns);
for (var i = 0; i < numMoves; i++)
for (var i = 0; i < NumMoves(); i++)
{
if (IsMoveValid(currentMove))
{
Expand All @@ -89,11 +111,14 @@ public IEnumerable<Move> ValidMoves()
}
}

/// <summary>
/// Iterate through all invalid Moves on the board.
/// </summary>
/// <returns></returns>
public IEnumerable<Move> InvalidMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
var numMoves = Move.NumPotentialMoves(Rows, Columns);
for (var i = 0; i < numMoves; i++)
for (var i = 0; i < NumMoves(); i++)
{
if (!IsMoveValid(currentMove))
{
Expand Down
23 changes: 23 additions & 0 deletions com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,31 @@ public void ResetData()

IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();

foreach (var move in m_Board.InvalidMoves())
{
numValidMoves--;
if (numValidMoves == 0)
{
// If all the moves are invalid and we mask all the actions out, this will cause an assert
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,
// (or log a warning if not) and leave the last action unmasked. This isn't great, but
// an invalid move should be easier to handle than an exception..
if (m_Board.OnNoValidMovesAction != null)
{
m_Board.OnNoValidMovesAction();
}
else
{
Debug.LogWarning(
"No valid moves are available. The last action will be left unmasked, so " +
"an invalid move will be passed to AbstractBoard.MakeMove()."
);
}
// This means the last move won't be returned as an invalid index.
yield break;
}
yield return move.MoveIndex;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Match3;

namespace Unity.MLAgents.Extensions.Tests.Match3
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;

namespace Unity.MLAgents.Extensions.Tests.Match3
{
internal class SimpleBoard : AbstractBoard
{
public int LastMoveIndex;
public bool MovesAreValid = true;

public bool CallbackCalled;

public override int GetCellType(int row, int col)
{
return 0;
}

public override int GetSpecialType(int row, int col)
{
return 0;
}

public override bool IsMoveValid(Move m)
{
return MovesAreValid;
}

public override bool MakeMove(Move m)
{
LastMoveIndex = m.MoveIndex;
return MovesAreValid;
}

public void Callback()
{
CallbackCalled = true;
}
}

public class Match3ActuatorTests
{
[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
}

[TestCase(true)]
[TestCase(false)]
public void TestValidMoves(bool movesAreValid)
{
// Check that a board with no valid moves doesn't raise an exception.
var gameObj = new GameObject();
var board = gameObj.AddComponent<SimpleBoard>();
var agent = gameObj.AddComponent<Agent>();
gameObj.AddComponent<Match3ActuatorComponent>();

board.Rows = 5;
board.Columns = 5;
board.NumCellTypes = 5;
board.NumSpecialTypes = 0;

board.MovesAreValid = movesAreValid;
board.OnNoValidMovesAction = board.Callback;
board.LastMoveIndex = -1;

agent.LazyInitialize();
agent.RequestDecision();
Academy.Instance.EnvironmentStep();

if (movesAreValid)
{
Assert.IsFalse(board.CallbackCalled);
}
else
{
Assert.IsTrue(board.CallbackCalled);
}
Assert.AreNotEqual(-1, board.LastMoveIndex);
}

}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using System;
using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Match3;

namespace Unity.MLAgents.Extensions.Tests.Match3
Expand Down