From a6f5f2c8541d167d7559341458e0ddeecf7d1f93 Mon Sep 17 00:00:00 2001 From: ramollia <> Date: Thu, 15 Jan 2026 01:37:33 +0100 Subject: [PATCH] mcts v1, v2, v3, v4 done. v5 wip --- app/src/main/java/org/toop/Main.java | 36 +- .../app/widget/view/LocalMultiplayerView.java | 11 +- .../model/game/TurnBasedGame.java | 3 + .../main/java/org/toop/game/BitboardGame.java | 7 +- .../game/games/reversi/BitboardReversi.java | 59 ++- .../games/tictactoe/BitboardTicTacToe.java | 10 + .../java/org/toop/game/players/ai/MCTSAI.java | 193 --------- .../org/toop/game/players/ai/MCTSAI1.java | 250 ++++++++++++ .../org/toop/game/players/ai/MCTSAI2.java | 126 +++++- .../org/toop/game/players/ai/MCTSAI3.java | 199 ++++++---- .../org/toop/game/players/ai/MCTSAI4.java | 359 +++++++++++++++++ .../org/toop/game/players/ai/MCTSAI5.java | 371 ++++++++++++++++++ 12 files changed, 1302 insertions(+), 322 deletions(-) delete mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI.java create mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI1.java create mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI4.java create mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI5.java diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index 874276a..a698ba6 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -1,36 +1,38 @@ package org.toop; import org.toop.app.App; -import org.toop.framework.gameFramework.model.player.AbstractPlayer; import org.toop.framework.gameFramework.model.player.Player; import org.toop.game.games.reversi.BitboardReversi; import org.toop.game.games.tictactoe.BitboardTicTacToe; import org.toop.game.players.ArtificialPlayer; -import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.MCTSAI1; import org.toop.game.players.ai.MCTSAI2; import org.toop.game.players.ai.MCTSAI3; -import org.toop.game.players.ai.RandomAI; +import org.toop.game.players.ai.MCTSAI4; +import org.toop.game.players.ai.MCTSAI5; public final class Main { static void main(String[] args) { - // App.run(args); - testMCTS(10); + App.run(args); + // testMCTS(100); } private static void testMCTS(int games) { - var random = new ArtificialPlayer<>(new RandomAI(), "Random AI"); - var v1 = new ArtificialPlayer<>(new MCTSAI(10), "MCTS V1 AI"); - var v2 = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); - var v2_2 = new ArtificialPlayer<>(new MCTSAI2(100), "MCTS V2_2 AI"); - var v3 = new ArtificialPlayer<>(new MCTSAI3(10), "MCTS V3 AI"); + var versions = new ArtificialPlayer[5]; + versions[0] = new ArtificialPlayer<>(new MCTSAI1(10), "MCTS V1 AI"); + versions[1] = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); + versions[2] = new ArtificialPlayer<>(new MCTSAI3(10, 10), "MCTS V3 AI"); + versions[3] = new ArtificialPlayer<>(new MCTSAI4(10, 10), "MCTS V4 AI"); + versions[4] = new ArtificialPlayer<>(new MCTSAI5(10, 10), "MCTS V5 AI"); - testAI(games, new Player[]{ v1, v2 }); - // testAI(games, new Player[]{ v1, v3 }); + for (int i = 2; i < versions.length; i++) { + for (int j = i + 1; j < versions.length; j++) { + final int playerIndex1 = i % versions.length; + final int playerIndex2 = j % versions.length; - // testAI(games, new Player[]{ random, v3 }); - // testAI(games, new Player[]{ v2, v3 }); - testAI(games, new Player[]{ v2, v3 }); - // testAI(games, new Player[]{ v3, v2 }); + testAI(games, new Player[] { versions[playerIndex1], versions[playerIndex2]}); + } + } } private static void testAI(int games, Player[] ais) { @@ -58,4 +60,4 @@ public final class Main { System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName()); System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games); } -} +} \ No newline at end of file diff --git a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java index 3e9e675..cfbf43f 100644 --- a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java +++ b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java @@ -15,9 +15,11 @@ import org.toop.app.widget.complex.PlayerInfoWidget; import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.tutorial.*; -import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.MCTSAI1; import org.toop.game.players.ai.MCTSAI2; import org.toop.game.players.ai.MCTSAI3; +import org.toop.game.players.ai.MCTSAI4; +import org.toop.game.players.ai.MCTSAI5; import org.toop.game.players.ai.MiniMaxAI; import org.toop.game.players.ai.RandomAI; import org.toop.local.AppContext; @@ -57,7 +59,7 @@ public class LocalMultiplayerView extends ViewWidget { if (information.players[0].isHuman) { players[0] = new LocalPlayer<>(information.players[0].name); } else { - players[0] = new ArtificialPlayer<>(new MCTSAI(100), "MCTS AI"); + players[0] = new ArtificialPlayer<>(new MCTSAI1(100), "MCTS AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer<>(information.players[1].name); @@ -85,13 +87,12 @@ public class LocalMultiplayerView extends ViewWidget { if (information.players[0].isHuman) { players[0] = new LocalPlayer<>(information.players[0].name); } else { - // players[0] = new ArtificialPlayer<>(new RandomAI(), "Random AI"); - players[0] = new ArtificialPlayer<>(new MCTSAI3(50), "MCTS V3 AI"); + players[0] = new ArtificialPlayer<>(new MCTSAI4(100, 3), "MCTS V4 AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer<>(information.players[1].name); } else { - players[1] = new ArtificialPlayer<>(new MCTSAI2(50), "MCTS V2 AI"); + players[1] = new ArtificialPlayer<>(new MCTSAI5(100, 3), "MCTS V5 AI"); } if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { new ShowEnableTutorialWidget( diff --git a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java index 41f9b8c..65c31d7 100644 --- a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java +++ b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java @@ -7,4 +7,7 @@ public interface TurnBasedGame> extends Playable, Dee PlayResult getState(); boolean isTerminal(); + + float rateMove(long move); + long heuristicMove(long legalMoves); } diff --git a/game/src/main/java/org/toop/game/BitboardGame.java b/game/src/main/java/org/toop/game/BitboardGame.java index 8e708b5..33f5453 100644 --- a/game/src/main/java/org/toop/game/BitboardGame.java +++ b/game/src/main/java/org/toop/game/BitboardGame.java @@ -6,7 +6,6 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.Player; import java.util.Arrays; -import java.util.concurrent.atomic.AtomicInteger; // There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy public abstract class BitboardGame> implements TurnBasedGame { @@ -19,7 +18,7 @@ public abstract class BitboardGame> implements TurnBas // long is 64 bits. Every game has a limit of 64 cells maximum. private final long[] playerBitboard; - private int currentTurn = 0; + protected int currentTurn = 0; public BitboardGame(int columnSize, int rowSize, int playerCount, Player[] players) { this.columnSize = columnSize; @@ -82,10 +81,6 @@ public abstract class BitboardGame> implements TurnBas return (currentTurn + 1) % playerBitboard.length; } - public Player getCurrentPlayer(){ - return players[getCurrentPlayerIndex()]; - } - @Override public PlayResult getState() { return state; diff --git a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java index aa8d5b8..ed7283c 100644 --- a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java +++ b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java @@ -6,9 +6,6 @@ import org.toop.framework.gameFramework.model.player.Player; import org.toop.game.BitboardGame; public class BitboardReversi extends BitboardGame { - - public record Score(int black, int white) {} - private final long notAFile = 0xfefefefefefefefeL; private final long notHFile = 0x7f7f7f7f7f7f7f7fL; @@ -253,7 +250,9 @@ public class BitboardReversi extends BitboardGame { } @Override - public BitboardReversi deepCopy() {return new BitboardReversi(this);} + public BitboardReversi deepCopy() { + return new BitboardReversi(this); + } public PlayResult play(long move) { final long flips = getFlips(move); @@ -296,13 +295,6 @@ public class BitboardReversi extends BitboardGame { return state; } - public Score getScore() { - return new Score( - Long.bitCount(getPlayerBitboard(0)), - Long.bitCount(getPlayerBitboard(1)) - ); - } - public int getWinner(){ final long black = getPlayerBitboard(0); final long white = getPlayerBitboard(1); @@ -316,8 +308,51 @@ public class BitboardReversi extends BitboardGame { else if (blackCount > whiteCount){ return 0; } - else{ + else { return 1; } } + + @Override + public float rateMove(long move) { + final long corners = 0x8100000000000081L; + + if ((move & corners) != 0L) { + return 0.4f; + } + + final long xSquares = 0x0042000000004200L; + + if ((move & xSquares) != 0) { + return -0.4f; + } + + final long cSquares = 0x4281000000008142L; + + if ((move & cSquares) != 0) { + return -0.1f; + } + + return 0.0f; + } + + @Override + public long heuristicMove(long legalMoves) { + long bestMove = 0L; + float bestMoveRate = Float.NEGATIVE_INFINITY; + + while (legalMoves != 0L) { + final long move = legalMoves & -legalMoves; + final float moveRate = rateMove(move); + + if (moveRate > bestMoveRate) { + bestMove = move; + bestMoveRate = moveRate; + } + + legalMoves &= ~move; + } + + return bestMove; + } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java b/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java index e06eff2..f0d945b 100644 --- a/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java +++ b/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java @@ -104,4 +104,14 @@ public class BitboardTicTacToe extends BitboardGame { public BitboardTicTacToe deepCopy() { return new BitboardTicTacToe(this); } + + @Override + public float rateMove(long move) { + return 0.0f; + } + + @Override + public long heuristicMove(long legalMoves) { + return legalMoves; + } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java deleted file mode 100644 index 7a30caa..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java +++ /dev/null @@ -1,193 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.GameState; -import org.toop.framework.gameFramework.model.game.PlayResult; -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.Random; - -public class MCTSAI> extends AbstractAI { - private static class Node { - public TurnBasedGame state; - public long move; - - public Node parent; - - public int expanded; - public Node[] children; - - public int visits; - public float value; - - public Node(TurnBasedGame state, long move, Node parent) { - this.state = state; - this.move = move; - - this.parent = parent; - - this.expanded = 0; - this.children = new Node[Long.bitCount(state.getLegalMoves())]; - - this.visits = 0; - this.value = 0.0f; - } - - public Node(TurnBasedGame state) { - this(state, 0L, null); - } - - public boolean isFullyExpanded() { - return expanded >= children.length; - } - - float calculateUCT() { - float exploitation = visits <= 0? 0 : value / visits; - float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - int bestChildIndex = -1; - float bestScore = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float score = calculateUCT(); - - if (score > bestScore) { - bestChildIndex = i; - bestScore = score; - } - } - - return bestChildIndex >= 0? children[bestChildIndex] : this; - } - } - - private final int milliseconds; - - public MCTSAI(int milliseconds) { - this.milliseconds = milliseconds; - } - - public MCTSAI(MCTSAI other) { - this.milliseconds = other.milliseconds; - } - - @Override - public MCTSAI deepCopy() { - return new MCTSAI<>(this); - } - - @Override - public long getMove(T game) { - Node root = new Node(game.deepCopy()); - - long endTime = System.currentTimeMillis() + milliseconds; - - while (System.currentTimeMillis() <= endTime) { - Node node = selection(root); - long legalMoves = node.state.getLegalMoves(); - - if (legalMoves != 0) { - node = expansion(node, legalMoves); - } - - float result = 0.0f; - - if (node.state.getLegalMoves() != 0) { - result = simulation(node.state, game.getCurrentTurn()); - } - - backPropagation(node, result); - } - - int mostVisitedIndex = -1; - int mostVisits = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisits) { - mostVisitedIndex = i; - mostVisits = root.children[i].visits; - } - } - - return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves()); - } - - private Node selection(Node node) { - while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) { - node = node.bestUCTChild(); - } - - return node; - } - - private Node expansion(Node node, long legalMoves) { - for (int i = 0; i < node.expanded; i++) { - legalMoves &= ~node.children[i].move; - } - - if (legalMoves == 0L) { - return node; - } - - long move = randomSetBit(legalMoves); - - TurnBasedGame copy = node.state.deepCopy(); - copy.play(move); - - Node newlyExpanded = new Node(copy, move, node); - - node.children[node.expanded] = newlyExpanded; - node.expanded++; - - return newlyExpanded; - } - - private float simulation(TurnBasedGame state, int playerIndex) { - TurnBasedGame copy = state.deepCopy(); - long legalMoves = copy.getLegalMoves(); - PlayResult result = null; - - while (legalMoves != 0) { - result = copy.play(randomSetBit(legalMoves)); - legalMoves = copy.getLegalMoves(); - } - - if (result.state() == GameState.WIN) { - if (result.player() == playerIndex) { - return 1.0f; - } - - return -1.0f; - } - - return -0.2f; - } - - private void backPropagation(Node node, float value) { - while (node != null) { - node.visits++; - node.value += value; - node = node.parent; - } - } - - public static long randomSetBit(long value) { - Random random = new Random(); - - int count = Long.bitCount(value); - int target = random.nextInt(count); - - while (true) { - int bit = Long.numberOfTrailingZeros(value); - if (target == 0) { - return 1L << bit; - } - value &= value - 1; - target--; - } - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java new file mode 100644 index 0000000..6621db4 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java @@ -0,0 +1,250 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.Random; + +public class MCTSAI1> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public boolean solved; + public float solvedValue; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + } + + return highestUCTChild; + } + } + + private static final Random random = new Random(); + + private final int milliseconds; + + public MCTSAI1(int milliseconds) { + this.milliseconds = milliseconds; + } + + public MCTSAI1(MCTSAI1 other) { + this.milliseconds = other.milliseconds; + } + + @Override + public MCTSAI1 deepCopy() { + return new MCTSAI1<>(this); + } + + @Override + public long getMove(T game) { + final Node root = new Node(game, null, 0L); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node selection(Node root) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } + + if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; + } + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java index bde88b2..8c616a6 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java @@ -20,6 +20,9 @@ public class MCTSAI2> extends AbstractAI { public float value; public int visits; + public boolean solved; + public float solvedValue; + public Node(TurnBasedGame state, Node parent, long move) { final long legalMoves = state.getLegalMoves(); @@ -35,6 +38,9 @@ public class MCTSAI2> extends AbstractAI { this.value = 0.0f; this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; } public Node(TurnBasedGame state) { @@ -46,6 +52,10 @@ public class MCTSAI2> extends AbstractAI { } public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + final float exploitation = value / visits; final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); @@ -63,24 +73,28 @@ public class MCTSAI2> extends AbstractAI { highestUCTChild = children[i]; highestUCT = childUCT; } - } return highestUCTChild; } } - private final Random random; + private static final Random random = new Random(); + private final int milliseconds; + private Node root; + public MCTSAI2(int milliseconds) { - this.random = new Random(); this.milliseconds = milliseconds; + + this.root = null; } - public MCTSAI2(MCTSAI2 other) { - this.random = other.random; + public MCTSAI2(MCTSAI2 other) { this.milliseconds = other.milliseconds; + + this.root = other.root; } @Override @@ -90,7 +104,7 @@ public class MCTSAI2> extends AbstractAI { @Override public long getMove(T game) { - final Node root = new Node(game, null, 0L); + root = findOrResetRoot(root, game); final long endTime = System.nanoTime() + milliseconds * 1_000_000L; @@ -102,8 +116,11 @@ public class MCTSAI2> extends AbstractAI { } final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; - return mostVisitedChild != null? mostVisitedChild.move : 0L; + root = findChildByMove(root, move); + + return move; } private Node mostVisitedChild(Node root) { @@ -120,8 +137,51 @@ public class MCTSAI2> extends AbstractAI { return mostVisitedChild; } + private Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + for (int i = 0; i < root.expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + private Node findChildByMove(Node root, long move) { + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + private boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } private Node selection(Node root) { - while (root.isFullyExpanded() && !root.state.isTerminal()) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { root = root.bestUCTChild(); } @@ -161,7 +221,9 @@ public class MCTSAI2> extends AbstractAI { if (copiedState.getWinner() == playerIndex) { return 1.0f; - } else if (copiedState.getWinner() >= 0) { + } + + if (copiedState.getWinner() >= 0) { return -1.0f; } @@ -173,11 +235,57 @@ public class MCTSAI2> extends AbstractAI { leaf.value += value; leaf.visits++; + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + value = -value; leaf = leaf.parent; } } + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + private long randomSetBit(long value) { if (0L == value) { return 0; diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java index efef955..2d85173 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java @@ -3,7 +3,13 @@ package org.toop.game.players.ai; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.AbstractAI; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; public class MCTSAI3> extends AbstractAI { private static class Node { @@ -20,6 +26,9 @@ public class MCTSAI3> extends AbstractAI { public float value; public int visits; + public boolean solved; + public float solvedValue; + public Node(TurnBasedGame state, Node parent, long move) { final long legalMoves = state.getLegalMoves(); @@ -35,6 +44,9 @@ public class MCTSAI3> extends AbstractAI { this.value = 0.0f; this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; } public Node(TurnBasedGame state) { @@ -46,6 +58,10 @@ public class MCTSAI3> extends AbstractAI { } public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + final float exploitation = value / visits; final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); @@ -63,30 +79,25 @@ public class MCTSAI3> extends AbstractAI { highestUCTChild = children[i]; highestUCT = childUCT; } - } return highestUCTChild; } } - private final Random random; + private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); - private Node root; private final int milliseconds; + private final int threads; - public MCTSAI3(int milliseconds) { - this.random = new Random(); - - this.root = null; + public MCTSAI3(int milliseconds, int threads) { this.milliseconds = milliseconds; + this.threads = threads; } - public MCTSAI3(MCTSAI3 other) { - this.random = other.random; - - this.root = other.root; + public MCTSAI3(MCTSAI3 other) { this.milliseconds = other.milliseconds; + this.threads = other.threads; } @Override @@ -96,23 +107,57 @@ public class MCTSAI3> extends AbstractAI { @Override public long getMove(T game) { - detectRoot(game); - + final ExecutorService pool = Executors.newFixedThreadPool(threads); final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - while (System.nanoTime() < endTime) { - Node leaf = selection(root); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + tasks.add(() -> { + final Node localRoot = new Node(game.deepCopy()); + + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); } - final Node mostVisitedChild = mostVisitedChild(root); - final long move = mostVisitedChild != null? mostVisitedChild.move : 0L; + try { + final List> results = pool.invokeAll(tasks); - newRoot(move); + pool.shutdown(); - return move; + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } catch (Exception _) { + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } } private Node mostVisitedChild(Node root) { @@ -129,62 +174,8 @@ public class MCTSAI3> extends AbstractAI { return mostVisitedChild; } - private void detectRoot(T game) { - if (root == null) { - root = new Node(game.deepCopy()); - return; - } - - final long[] currentBoards = game.getBoard(); - final long[] rootBoards = root.state.getBoard(); - - boolean detected = true; - - for (int i = 0; i < rootBoards.length; i++) { - if (rootBoards[i] != currentBoards[i]) { - detected = false; - break; - } - } - - if (detected) { - return; - } - - for (int i = 0; i < root.expanded; i++) { - final Node child = root.children[i]; - - final long[] childBoards = child.state.getBoard(); - - detected = true; - - for (int j = 0; j < childBoards.length; j++) { - if (childBoards[j] != currentBoards[j]) { - detected = false; - break; - } - } - - if (detected) { - root = child; - return; - } - } - - root = new Node(game.deepCopy()); - } - - private void newRoot(long move) { - for (final Node child : root.children) { - if (child.move == move) { - root = child; - break; - } - } - } - private Node selection(Node root) { - while (root.isFullyExpanded() && !root.state.isTerminal()) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { root = root.bestUCTChild(); } @@ -224,7 +215,9 @@ public class MCTSAI3> extends AbstractAI { if (copiedState.getWinner() == playerIndex) { return 1.0f; - } else if (copiedState.getWinner() >= 0) { + } + + if (copiedState.getWinner() >= 0) { return -1.0f; } @@ -236,18 +229,64 @@ public class MCTSAI3> extends AbstractAI { leaf.value += value; leaf.visits++; + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + value = -value; leaf = leaf.parent; } } + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + private long randomSetBit(long value) { if (0L == value) { return 0; } final int bitCount = Long.bitCount(value); - final int randomBitCount = random.nextInt(bitCount); + final int randomBitCount = random.get().nextInt(bitCount); for (int i = 0; i < randomBitCount; i++) { value &= value - 1; @@ -255,4 +294,4 @@ public class MCTSAI3> extends AbstractAI { return value & -value; } -} +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java new file mode 100644 index 0000000..262d9ae --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java @@ -0,0 +1,359 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI4> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public boolean solved; + public float solvedValue; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + } + + return highestUCTChild; + } + } + + private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); + + private final int milliseconds; + private final int threads; + + private final Node[] threadRoots; + + public MCTSAI4(int milliseconds, int threads) { + this.milliseconds = milliseconds; + this.threads = threads; + + this.threadRoots = new Node[threads]; + } + + public MCTSAI4(MCTSAI4 other) { + this.milliseconds = other.milliseconds; + this.threads = other.threads; + + this.threadRoots = other.threadRoots; + } + + @Override + public MCTSAI4 deepCopy() { + return new MCTSAI4<>(this); + } + + @Override + public long getMove(T game) { + for (int i = 0; i < threads; i++) { + threadRoots[i] = findOrResetRoot(threadRoots[i], game); + } + + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + final int threadIndex = i; + + tasks.add(() -> { + final Node localRoot = threadRoots[threadIndex]; + + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + for (int i = 0; i < threads; i++) { + threadRoots[i] = findChildByMove(threadRoots[i], move); + } + + return move; + } catch (Exception _) { + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + for (int i = 0; i < root.expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + private Node findChildByMove(Node root, long move) { + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + private boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } + + private Node selection(Node root) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } + + if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; + } + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.get().nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java new file mode 100644 index 0000000..452eaf8 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java @@ -0,0 +1,371 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI5> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public boolean solved; + public float solvedValue; + + public float heuristic; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + + this.solved = false; + this.solvedValue = 0.0f; + + this.heuristic = state.rateMove(move); + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } + + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + final float bias = heuristic / visits; + + return exploitation + exploration + bias; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + } + + return highestUCTChild; + } + } + + private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); + + private final int milliseconds; + private final int threads; + + private final Node[] threadRoots; + + public MCTSAI5(int milliseconds, int threads) { + this.milliseconds = milliseconds; + this.threads = threads; + + this.threadRoots = new Node[threads]; + } + + public MCTSAI5(MCTSAI5 other) { + this.milliseconds = other.milliseconds; + this.threads = other.threads; + + this.threadRoots = other.threadRoots; + } + + @Override + public MCTSAI5 deepCopy() { + return new MCTSAI5<>(this); + } + + @Override + public long getMove(T game) { + for (int i = 0; i < threads; i++) { + threadRoots[i] = findOrResetRoot(threadRoots[i], game); + } + + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + final int threadIndex = i; + + tasks.add(() -> { + final Node localRoot = threadRoots[threadIndex]; + + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + for (int i = 0; i < threads; i++) { + threadRoots[i] = findChildByMove(threadRoots[i], move); + } + + return move; + } catch (Exception _) { + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + for (int i = 0; i < root.expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + private Node findChildByMove(Node root, long move) { + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + private boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } + + private Node selection(Node root) { + while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + + long move = 0L; + + if (random.get().nextFloat() > 0.9f) { + move = copiedState.heuristicMove(legalMoves); + } else { + move = randomSetBit(legalMoves); + } + + copiedState.play(move); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } + + if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (!leaf.solved) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; + } + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + node.solved = true; + + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (child.solved) { + if (child.solvedValue == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solvedValue == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = true; + node.solvedValue = 1.0f; + } else if (allChildrenSolved) { + node.solved = true; + node.solvedValue = foundDrawMove? 0.0f : -1.0f; + } + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.get().nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file