diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index bd330d2..6b1e779 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -1,53 +1,63 @@ package org.toop; import org.toop.app.App; +import org.toop.framework.gameFramework.model.player.AbstractPlayer; +import org.toop.framework.gameFramework.model.player.Player; +import org.toop.game.games.reversi.BitboardReversi; +import org.toop.game.games.tictactoe.BitboardTicTacToe; +import org.toop.game.players.ArtificialPlayer; +import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.MCTSAI2; +import org.toop.game.players.ai.MCTSAI3; +import org.toop.game.players.ai.RandomAI; public final class Main { static void main(String[] args) { - App.run(args); - // testMCTS(10); + App.run(args); + // testMCTS(100); } - // Voor onderzoek - // private static void testMCTS(int games) { - // var random = new ArtificialPlayer<>(new RandomAI(), "Random AI"); - // var v1 = new ArtificialPlayer<>(new MCTSAI(10), "MCTS V1 AI"); - // var v2 = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); - // var v2_2 = new ArtificialPlayer<>(new MCTSAI2(100), "MCTS V2_2 AI"); - // var v3 = new ArtificialPlayer<>(new MCTSAI3(10), "MCTS V3 AI"); + private static void testMCTS(int games) { + var versions = new ArtificialPlayer[5]; + versions[0] = new ArtificialPlayer<>(new MCTSAI1(10), "MCTS V1 AI"); + versions[1] = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); + versions[2] = new ArtificialPlayer<>(new MCTSAI3(10, 10), "MCTS V3 AI"); + versions[3] = new ArtificialPlayer<>(new MCTSAI4(10, 10), "MCTS V4 AI"); + versions[4] = new ArtificialPlayer<>(new MCTSAI5(10, 10), "MCTS V5 AI"); - // testAI(games, new Player[]{ v1, v2 }); - // // testAI(games, new Player[]{ v1, v3 }); + for (int i = 2; i < versions.length; i++) { + for (int j = i + 1; j < versions.length; j++) { + final int playerIndex1 = i % versions.length; + final int playerIndex2 = j % versions.length; - // // testAI(games, new Player[]{ random, v3 }); - // // testAI(games, new Player[]{ v2, v3 }); - // testAI(games, new Player[]{ v2, v3 }); - // // testAI(games, new Player[]{ v3, v2 }); - // } + testAI(games, new Player[] { versions[playerIndex1], versions[playerIndex2]}); + } + } + } - // private static void testAI(int games, Player[] ais) { - // int wins = 0; - // int ties = 0; + private static void testAI(int games, Player[] ais) { + int wins = 0; + int ties = 0; - // for (int i = 0; i < games; i++) { - // final BitboardReversi match = new BitboardReversi(ais); + for (int i = 0; i < games; i++) { + final BitboardReversi match = new BitboardReversi(ais); - // while (!match.isTerminal()) { - // final int currentAI = match.getCurrentTurn(); - // final long move = ais[currentAI].getMove(match); + while (!match.isTerminal()) { + final int currentAI = match.getCurrentTurn(); + final long move = ais[currentAI].getMove(match); - // match.play(move); - // } + match.play(move); + } - // if (match.getWinner() < 0) { - // ties++; - // continue; - // } + if (match.getWinner() < 0) { + ties++; + continue; + } - // wins += match.getWinner() == 0? 1 : 0; - // } + wins += match.getWinner() == 0? 1 : 0; + } - // System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName()); - // System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games); - // } + System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName()); + System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games); + } } diff --git a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java index afee3f2..1a6f1da 100644 --- a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java +++ b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java @@ -12,11 +12,13 @@ import org.toop.app.widget.complex.PlayerInfoWidget; import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.tutorial.*; -import org.toop.framework.game.players.LocalPlayer; -import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.MCTSAI1; import org.toop.game.players.ai.MCTSAI2; import org.toop.game.players.ai.MCTSAI3; +import org.toop.game.players.ai.MCTSAI4; +import org.toop.game.players.ai.MCTSAI5; import org.toop.game.players.ai.MiniMaxAI; +import org.toop.game.players.ai.RandomAI; import org.toop.local.AppContext; import javafx.geometry.Pos; diff --git a/framework/src/main/java/org/toop/framework/game/BitboardGame.java b/framework/src/main/java/org/toop/framework/game/BitboardGame.java index 76997c5..26d7c1e 100644 --- a/framework/src/main/java/org/toop/framework/game/BitboardGame.java +++ b/framework/src/main/java/org/toop/framework/game/BitboardGame.java @@ -6,6 +6,7 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.Player; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; // There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy public abstract class BitboardGame implements TurnBasedGame { diff --git a/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java b/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java index 102b94d..c093dd7 100644 --- a/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java +++ b/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java @@ -321,8 +321,51 @@ public class BitboardReversi extends BitboardGame { else if (blackCount > whiteCount){ return 0; } - else{ + else { return 1; } } + + @Override + public float rateMove(long move) { + final long corners = 0x8100000000000081L; + + if ((move & corners) != 0L) { + return 0.4f; + } + + final long xSquares = 0x0042000000004200L; + + if ((move & xSquares) != 0) { + return -0.4f; + } + + final long cSquares = 0x4281000000008142L; + + if ((move & cSquares) != 0) { + return -0.1f; + } + + return 0.0f; + } + + @Override + public long heuristicMove(long legalMoves) { + long bestMove = 0L; + float bestMoveRate = Float.NEGATIVE_INFINITY; + + while (legalMoves != 0L) { + final long move = legalMoves & -legalMoves; + final float moveRate = rateMove(move); + + if (moveRate > bestMoveRate) { + bestMove = move; + bestMoveRate = moveRate; + } + + legalMoves &= ~move; + } + + return bestMove; + } } \ No newline at end of file diff --git a/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java b/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java index fa89efc..cf0f35a 100644 --- a/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java +++ b/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java @@ -110,4 +110,14 @@ public class BitboardTicTacToe extends BitboardGame { public BitboardTicTacToe deepCopy() { return new BitboardTicTacToe(this); } + + @Override + public float rateMove(long move) { + return 0.0f; + } + + @Override + public long heuristicMove(long legalMoves) { + return legalMoves; + } } \ No newline at end of file diff --git a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java index 6f5f8a7..6cd5996 100644 --- a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java +++ b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java @@ -13,4 +13,7 @@ public interface TurnBasedGame extends DeepCopyable { PlayResult play(long move); PlayResult getState(); boolean isTerminal(); + + float rateMove(long move); + long heuristicMove(long legalMoves); } diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java deleted file mode 100644 index 8687846..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java +++ /dev/null @@ -1,193 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.GameState; -import org.toop.framework.gameFramework.model.game.PlayResult; -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.Random; - -public class MCTSAI extends AbstractAI { - private static class Node { - public TurnBasedGame state; - public long move; - - public Node parent; - - public int expanded; - public Node[] children; - - public int visits; - public float value; - - public Node(TurnBasedGame state, long move, Node parent) { - this.state = state; - this.move = move; - - this.parent = parent; - - this.expanded = 0; - this.children = new Node[Long.bitCount(state.getLegalMoves())]; - - this.visits = 0; - this.value = 0.0f; - } - - public Node(TurnBasedGame state) { - this(state, 0L, null); - } - - public boolean isFullyExpanded() { - return expanded >= children.length; - } - - float calculateUCT() { - float exploitation = visits <= 0? 0 : value / visits; - float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - int bestChildIndex = -1; - float bestScore = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float score = calculateUCT(); - - if (score > bestScore) { - bestChildIndex = i; - bestScore = score; - } - } - - return bestChildIndex >= 0? children[bestChildIndex] : this; - } - } - - private final int milliseconds; - - public MCTSAI(int milliseconds) { - this.milliseconds = milliseconds; - } - - public MCTSAI(MCTSAI other) { - this.milliseconds = other.milliseconds; - } - - @Override - public MCTSAI deepCopy() { - return new MCTSAI(this); - } - - @Override - public long getMove(TurnBasedGame game) { - Node root = new Node(game.deepCopy()); - - long endTime = System.currentTimeMillis() + milliseconds; - - while (System.currentTimeMillis() <= endTime) { - Node node = selection(root); - long legalMoves = node.state.getLegalMoves(); - - if (legalMoves != 0) { - node = expansion(node, legalMoves); - } - - float result = 0.0f; - - if (node.state.getLegalMoves() != 0) { - result = simulation(node.state, game.getCurrentTurn()); - } - - backPropagation(node, result); - } - - int mostVisitedIndex = -1; - int mostVisits = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisits) { - mostVisitedIndex = i; - mostVisits = root.children[i].visits; - } - } - - return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves()); - } - - private Node selection(Node node) { - while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) { - node = node.bestUCTChild(); - } - - return node; - } - - private Node expansion(Node node, long legalMoves) { - for (int i = 0; i < node.expanded; i++) { - legalMoves &= ~node.children[i].move; - } - - if (legalMoves == 0L) { - return node; - } - - long move = randomSetBit(legalMoves); - - TurnBasedGame copy = node.state.deepCopy(); - copy.play(move); - - Node newlyExpanded = new Node(copy, move, node); - - node.children[node.expanded] = newlyExpanded; - node.expanded++; - - return newlyExpanded; - } - - private float simulation(TurnBasedGame state, int playerIndex) { - TurnBasedGame copy = state.deepCopy(); - long legalMoves = copy.getLegalMoves(); - PlayResult result = null; - - while (legalMoves != 0) { - result = copy.play(randomSetBit(legalMoves)); - legalMoves = copy.getLegalMoves(); - } - - if (result.state() == GameState.WIN) { - if (result.player() == playerIndex) { - return 1.0f; - } - - return -1.0f; - } - - return -0.2f; - } - - private void backPropagation(Node node, float value) { - while (node != null) { - node.visits++; - node.value += value; - node = node.parent; - } - } - - public static long randomSetBit(long value) { - Random random = new Random(); - - int count = Long.bitCount(value); - int target = random.nextInt(count); - - while (true) { - int bit = Long.numberOfTrailingZeros(value); - if (target == 0) { - return 1L << bit; - } - value &= value - 1; - target--; - } - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java new file mode 100644 index 0000000..6621db4 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java @@ -0,0 +1,250 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.Random; + +public class MCTSAI1> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public boolean solved; + public float solvedValue; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + } + + return highestUCTChild; + } + } + + private static final Random random = new Random(); + + private final int milliseconds; + + public MCTSAI1(int milliseconds) { + this.milliseconds = milliseconds; + } + + public MCTSAI1(MCTSAI1 other) { + this.milliseconds = other.milliseconds; + } + + @Override + public MCTSAI1 deepCopy() { + return new MCTSAI1<>(this); + } + + @Override + public long getMove(T game) { + final Node root = new Node(game, null, 0L); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node selection(Node root) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } + + if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; + } + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java index 872c693..380390c 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java @@ -35,6 +35,9 @@ public class MCTSAI2 extends AbstractAI { this.value = 0.0f; this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; } public Node(TurnBasedGame state) { @@ -46,6 +49,10 @@ public class MCTSAI2 extends AbstractAI { } public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + final float exploitation = value / visits; final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); @@ -63,24 +70,29 @@ public class MCTSAI2 extends AbstractAI { highestUCTChild = children[i]; highestUCT = childUCT; } - } return highestUCTChild; } } - private final Random random; + private static final Random random = new Random(); + private final int milliseconds; + private Node root; + public MCTSAI2(int milliseconds) { - this.random = new Random(); this.milliseconds = milliseconds; + + this.root = null; } public MCTSAI2(MCTSAI2 other) { this.random = other.random; this.milliseconds = other.milliseconds; + + this.root = other.root; } @Override @@ -90,7 +102,7 @@ public class MCTSAI2 extends AbstractAI { @Override public long getMove(TurnBasedGame game) { - final Node root = new Node(game, null, 0L); + root = findOrResetRoot(root, game); final long endTime = System.nanoTime() + milliseconds * 1_000_000L; @@ -102,8 +114,11 @@ public class MCTSAI2 extends AbstractAI { } final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; - return mostVisitedChild != null? mostVisitedChild.move : 0L; + root = findChildByMove(root, move); + + return move; } private Node mostVisitedChild(Node root) { @@ -120,8 +135,51 @@ public class MCTSAI2 extends AbstractAI { return mostVisitedChild; } + private Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + for (int i = 0; i < root.expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + private Node findChildByMove(Node root, long move) { + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + private boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } private Node selection(Node root) { - while (root.isFullyExpanded() && !root.state.isTerminal()) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { root = root.bestUCTChild(); } @@ -161,7 +219,9 @@ public class MCTSAI2 extends AbstractAI { if (copiedState.getWinner() == playerIndex) { return 1.0f; - } else if (copiedState.getWinner() >= 0) { + } + + if (copiedState.getWinner() >= 0) { return -1.0f; } @@ -173,11 +233,57 @@ public class MCTSAI2 extends AbstractAI { leaf.value += value; leaf.visits++; + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + value = -value; leaf = leaf.parent; } } + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + private long randomSetBit(long value) { if (0L == value) { return 0; diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java index fe892d1..368c89d 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java @@ -3,7 +3,13 @@ package org.toop.game.players.ai; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.AbstractAI; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; public class MCTSAI3 extends AbstractAI { private static class Node { @@ -20,6 +26,9 @@ public class MCTSAI3 extends AbstractAI { public float value; public int visits; + public boolean solved; + public float solvedValue; + public Node(TurnBasedGame state, Node parent, long move) { final long legalMoves = state.getLegalMoves(); @@ -35,6 +44,9 @@ public class MCTSAI3 extends AbstractAI { this.value = 0.0f; this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; } public Node(TurnBasedGame state) { @@ -46,6 +58,10 @@ public class MCTSAI3 extends AbstractAI { } public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + final float exploitation = value / visits; final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); @@ -63,23 +79,20 @@ public class MCTSAI3 extends AbstractAI { highestUCTChild = children[i]; highestUCT = childUCT; } - } return highestUCTChild; } } - private final Random random; + private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); - private Node root; private final int milliseconds; + private final int threads; - public MCTSAI3(int milliseconds) { - this.random = new Random(); - - this.root = null; + public MCTSAI3(int milliseconds, int threads) { this.milliseconds = milliseconds; + this.threads = threads; } public MCTSAI3(MCTSAI3 other) { @@ -87,6 +100,7 @@ public class MCTSAI3 extends AbstractAI { this.root = other.root; this.milliseconds = other.milliseconds; + this.threads = other.threads; } @Override @@ -95,24 +109,58 @@ public class MCTSAI3 extends AbstractAI { } @Override - public long getMove(TurnBasedGame game) { - detectRoot(game); - + public long getMove(T game) { + final ExecutorService pool = Executors.newFixedThreadPool(threads); final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - while (System.nanoTime() < endTime) { - Node leaf = selection(root); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + tasks.add(() -> { + final Node localRoot = new Node(game.deepCopy()); + + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); } - final Node mostVisitedChild = mostVisitedChild(root); - final long move = mostVisitedChild != null? mostVisitedChild.move : 0L; + try { + final List> results = pool.invokeAll(tasks); - newRoot(move); + pool.shutdown(); - return move; + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } catch (Exception _) { + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } } private Node mostVisitedChild(Node root) { @@ -184,7 +232,7 @@ public class MCTSAI3 extends AbstractAI { } private Node selection(Node root) { - while (root.isFullyExpanded() && !root.state.isTerminal()) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { root = root.bestUCTChild(); } @@ -224,7 +272,9 @@ public class MCTSAI3 extends AbstractAI { if (copiedState.getWinner() == playerIndex) { return 1.0f; - } else if (copiedState.getWinner() >= 0) { + } + + if (copiedState.getWinner() >= 0) { return -1.0f; } @@ -236,18 +286,64 @@ public class MCTSAI3 extends AbstractAI { leaf.value += value; leaf.visits++; + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + value = -value; leaf = leaf.parent; } } + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + private long randomSetBit(long value) { if (0L == value) { return 0; } final int bitCount = Long.bitCount(value); - final int randomBitCount = random.nextInt(bitCount); + final int randomBitCount = random.get().nextInt(bitCount); for (int i = 0; i < randomBitCount; i++) { value &= value - 1; diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java new file mode 100644 index 0000000..262d9ae --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java @@ -0,0 +1,359 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI4> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public boolean solved; + public float solvedValue; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + } + + return highestUCTChild; + } + } + + private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); + + private final int milliseconds; + private final int threads; + + private final Node[] threadRoots; + + public MCTSAI4(int milliseconds, int threads) { + this.milliseconds = milliseconds; + this.threads = threads; + + this.threadRoots = new Node[threads]; + } + + public MCTSAI4(MCTSAI4 other) { + this.milliseconds = other.milliseconds; + this.threads = other.threads; + + this.threadRoots = other.threadRoots; + } + + @Override + public MCTSAI4 deepCopy() { + return new MCTSAI4<>(this); + } + + @Override + public long getMove(T game) { + for (int i = 0; i < threads; i++) { + threadRoots[i] = findOrResetRoot(threadRoots[i], game); + } + + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + final int threadIndex = i; + + tasks.add(() -> { + final Node localRoot = threadRoots[threadIndex]; + + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + for (int i = 0; i < threads; i++) { + threadRoots[i] = findChildByMove(threadRoots[i], move); + } + + return move; + } catch (Exception _) { + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + for (int i = 0; i < root.expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + private Node findChildByMove(Node root, long move) { + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + private boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } + + private Node selection(Node root) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } + + if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; + } + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.get().nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java new file mode 100644 index 0000000..452eaf8 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java @@ -0,0 +1,371 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI5> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public boolean solved; + public float solvedValue; + + public float heuristic; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; + + this.heuristic = state.rateMove(move); + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + final float bias = heuristic / visits; + + return exploitation + exploration + bias; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + } + + return highestUCTChild; + } + } + + private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); + + private final int milliseconds; + private final int threads; + + private final Node[] threadRoots; + + public MCTSAI5(int milliseconds, int threads) { + this.milliseconds = milliseconds; + this.threads = threads; + + this.threadRoots = new Node[threads]; + } + + public MCTSAI5(MCTSAI5 other) { + this.milliseconds = other.milliseconds; + this.threads = other.threads; + + this.threadRoots = other.threadRoots; + } + + @Override + public MCTSAI5 deepCopy() { + return new MCTSAI5<>(this); + } + + @Override + public long getMove(T game) { + for (int i = 0; i < threads; i++) { + threadRoots[i] = findOrResetRoot(threadRoots[i], game); + } + + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + final int threadIndex = i; + + tasks.add(() -> { + final Node localRoot = threadRoots[threadIndex]; + + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + for (int i = 0; i < threads; i++) { + threadRoots[i] = findChildByMove(threadRoots[i], move); + } + + return move; + } catch (Exception _) { + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + for (int i = 0; i < root.expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + private Node findChildByMove(Node root, long move) { + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + private boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } + + private Node selection(Node root) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + + long move = 0L; + + if (random.get().nextFloat() > 0.9f) { + move = copiedState.heuristicMove(legalMoves); + } else { + move = randomSetBit(legalMoves); + } + + copiedState.play(move); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } + + if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; + } + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.get().nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file