From df93b44d19007c27809d57a4305aca3110e3563f Mon Sep 17 00:00:00 2001 From: ramollia <> Date: Wed, 7 Jan 2026 14:39:38 +0100 Subject: [PATCH] bitboard fix & mcts v2 & mcts v3. v3 still in progress and v4 coming soon --- app/src/main/java/org/toop/Main.java | 54 +++- .../app/widget/view/LocalMultiplayerView.java | 9 +- .../model/game/TurnBasedGame.java | 3 + .../model/player/AbstractPlayer.java | 1 - .../main/java/org/toop/game/BitboardGame.java | 23 +- .../game/games/reversi/BitboardReversi.java | 24 +- .../games/tictactoe/BitboardTicTacToe.java | 12 +- .../java/org/toop/game/players/ai/MCTSAI.java | 18 +- .../org/toop/game/players/ai/MCTSAI2.java | 195 +++++++++++++ .../org/toop/game/players/ai/MCTSAI3.java | 258 ++++++++++++++++++ 10 files changed, 568 insertions(+), 29 deletions(-) create mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI2.java create mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI3.java diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index 3b4fef3..874276a 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -1,9 +1,61 @@ package org.toop; import org.toop.app.App; +import org.toop.framework.gameFramework.model.player.AbstractPlayer; +import org.toop.framework.gameFramework.model.player.Player; +import org.toop.game.games.reversi.BitboardReversi; +import org.toop.game.games.tictactoe.BitboardTicTacToe; +import org.toop.game.players.ArtificialPlayer; +import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.MCTSAI2; +import org.toop.game.players.ai.MCTSAI3; +import org.toop.game.players.ai.RandomAI; public final class Main { static void main(String[] args) { - App.run(args); + // App.run(args); + testMCTS(10); } + + private static void testMCTS(int games) { + var random = new ArtificialPlayer<>(new RandomAI(), "Random AI"); + var v1 = new ArtificialPlayer<>(new MCTSAI(10), "MCTS V1 AI"); + var v2 = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); + var v2_2 = new ArtificialPlayer<>(new MCTSAI2(100), "MCTS V2_2 AI"); + var v3 = new ArtificialPlayer<>(new MCTSAI3(10), "MCTS V3 AI"); + + testAI(games, new Player[]{ v1, v2 }); + // testAI(games, new Player[]{ v1, v3 }); + + // testAI(games, new Player[]{ random, v3 }); + // testAI(games, new Player[]{ v2, v3 }); + testAI(games, new Player[]{ v2, v3 }); + // testAI(games, new Player[]{ v3, v2 }); + } + + private static void testAI(int games, Player[] ais) { + int wins = 0; + int ties = 0; + + for (int i = 0; i < games; i++) { + final BitboardReversi match = new BitboardReversi(ais); + + while (!match.isTerminal()) { + final int currentAI = match.getCurrentTurn(); + final long move = ais[currentAI].getMove(match); + + match.play(move); + } + + if (match.getWinner() < 0) { + ties++; + continue; + } + + wins += match.getWinner() == 0? 1 : 0; + } + + System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName()); + System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games); + } } diff --git a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java index 139a33c..3e9e675 100644 --- a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java +++ b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java @@ -16,6 +16,8 @@ import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.tutorial.*; import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.MCTSAI2; +import org.toop.game.players.ai.MCTSAI3; import org.toop.game.players.ai.MiniMaxAI; import org.toop.game.players.ai.RandomAI; import org.toop.local.AppContext; @@ -55,7 +57,7 @@ public class LocalMultiplayerView extends ViewWidget { if (information.players[0].isHuman) { players[0] = new LocalPlayer<>(information.players[0].name); } else { - players[0] = new ArtificialPlayer<>(new RandomAI(), "Random AI"); + players[0] = new ArtificialPlayer<>(new MCTSAI(100), "MCTS AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer<>(information.players[1].name); @@ -83,12 +85,13 @@ public class LocalMultiplayerView extends ViewWidget { if (information.players[0].isHuman) { players[0] = new LocalPlayer<>(information.players[0].name); } else { - players[0] = new ArtificialPlayer<>(new RandomAI(), "Random AI"); + // players[0] = new ArtificialPlayer<>(new RandomAI(), "Random AI"); + players[0] = new ArtificialPlayer<>(new MCTSAI3(50), "MCTS V3 AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer<>(information.players[1].name); } else { - players[1] = new ArtificialPlayer<>(new MCTSAI(1000), "MCTS AI"); + players[1] = new ArtificialPlayer<>(new MCTSAI2(50), "MCTS V2 AI"); } if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { new ShowEnableTutorialWidget( diff --git a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java index d4cb4df..41f9b8c 100644 --- a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java +++ b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java @@ -4,4 +4,7 @@ public interface TurnBasedGame> extends Playable, Dee int getCurrentTurn(); int getPlayerCount(); int getWinner(); + + PlayResult getState(); + boolean isTerminal(); } diff --git a/framework/src/main/java/org/toop/framework/gameFramework/model/player/AbstractPlayer.java b/framework/src/main/java/org/toop/framework/gameFramework/model/player/AbstractPlayer.java index 52b0de4..601b93d 100644 --- a/framework/src/main/java/org/toop/framework/gameFramework/model/player/AbstractPlayer.java +++ b/framework/src/main/java/org/toop/framework/gameFramework/model/player/AbstractPlayer.java @@ -11,7 +11,6 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame; */ public abstract class AbstractPlayer> implements Player { - private final Logger logger = LogManager.getLogger(this.getClass()); private final String name; /** diff --git a/game/src/main/java/org/toop/game/BitboardGame.java b/game/src/main/java/org/toop/game/BitboardGame.java index 4ebdb95..8e708b5 100644 --- a/game/src/main/java/org/toop/game/BitboardGame.java +++ b/game/src/main/java/org/toop/game/BitboardGame.java @@ -1,5 +1,7 @@ package org.toop.game; +import org.toop.framework.gameFramework.GameState; +import org.toop.framework.gameFramework.model.game.PlayResult; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.Player; @@ -11,6 +13,8 @@ public abstract class BitboardGame> implements TurnBas private final int columnSize; private final int rowSize; + protected PlayResult state; + private Player[] players; // long is 64 bits. Every game has a limit of 64 cells maximum. @@ -20,6 +24,9 @@ public abstract class BitboardGame> implements TurnBas public BitboardGame(int columnSize, int rowSize, int playerCount, Player[] players) { this.columnSize = columnSize; this.rowSize = rowSize; + + this.state = new PlayResult(GameState.NORMAL, -1); + this.players = players; this.playerBitboard = new long[playerCount]; @@ -30,6 +37,8 @@ public abstract class BitboardGame> implements TurnBas this.columnSize = other.columnSize; this.rowSize = other.rowSize; + this.state = other.state; + this.playerBitboard = other.playerBitboard.clone(); this.currentTurn = other.currentTurn; this.players = Arrays.stream(other.players) @@ -61,7 +70,9 @@ public abstract class BitboardGame> implements TurnBas return getCurrentPlayerIndex(); } - public Player getPlayer(int index) {return players[index];} + public Player getPlayer(int index) { + return players[index]; + } public int getCurrentPlayerIndex() { return currentTurn % playerBitboard.length; @@ -75,9 +86,17 @@ public abstract class BitboardGame> implements TurnBas return players[getCurrentPlayerIndex()]; } + @Override + public PlayResult getState() { + return state; + } + @Override + public boolean isTerminal() { + return state.state() == GameState.WIN || state.state() == GameState.DRAW; + } - @Override + @Override public long[] getBoard() {return this.playerBitboard;} public void nextTurn() { diff --git a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java index 23816d7..aa8d5b8 100644 --- a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java +++ b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java @@ -175,7 +175,7 @@ public class BitboardReversi extends BitboardGame { direction |= (direction << 1) & mask; direction |= (direction << 1) & mask; - if (((direction << 1) & player) != 0) { + if (((direction << 1) & player & notAFile) != 0) { flips |= direction; } @@ -189,7 +189,7 @@ public class BitboardReversi extends BitboardGame { direction |= (direction >>> 1) & mask; direction |= (direction >>> 1) & mask; - if (((direction >>> 1) & player) != 0) { + if (((direction >>> 1) & player & notHFile) != 0) { flips |= direction; } @@ -203,7 +203,7 @@ public class BitboardReversi extends BitboardGame { direction |= (direction << 9) & mask; direction |= (direction << 9) & mask; - if (((direction << 9) & player) != 0) { + if (((direction << 9) & player & notAFile) != 0) { flips |= direction; } @@ -217,7 +217,7 @@ public class BitboardReversi extends BitboardGame { direction |= (direction << 7) & mask; direction |= (direction << 7) & mask; - if (((direction << 7) & player) != 0) { + if (((direction << 7) & player & notHFile) != 0) { flips |= direction; } @@ -231,7 +231,7 @@ public class BitboardReversi extends BitboardGame { direction |= (direction >>> 7) & mask; direction |= (direction >>> 7) & mask; - if (((direction >>> 7) & player) != 0) { + if (((direction >>> 7) & player & notAFile) != 0) { flips |= direction; } @@ -245,7 +245,7 @@ public class BitboardReversi extends BitboardGame { direction |= (direction >>> 9) & mask; direction |= (direction >>> 9) & mask; - if (((direction >>> 9) & player) != 0) { + if (((direction >>> 9) & player & notHFile) != 0) { flips |= direction; } @@ -280,16 +280,20 @@ public class BitboardReversi extends BitboardGame { int winner = getWinner(); if (winner == -1) { - return new PlayResult(GameState.DRAW, -1); + state = new PlayResult(GameState.DRAW, -1); + return state; } - return new PlayResult(GameState.WIN, winner); + state = new PlayResult(GameState.WIN, winner); + return state; } - return new PlayResult(GameState.TURN_SKIPPED, getCurrentPlayerIndex()); + state = new PlayResult(GameState.TURN_SKIPPED, getCurrentPlayerIndex()); + return state; } - return new PlayResult(GameState.NORMAL, getCurrentPlayerIndex()); + state = new PlayResult(GameState.NORMAL, getCurrentPlayerIndex()); + return state; } public Score getScore() { diff --git a/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java b/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java index 0927431..e06eff2 100644 --- a/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java +++ b/game/src/main/java/org/toop/game/games/tictactoe/BitboardTicTacToe.java @@ -39,7 +39,8 @@ public class BitboardTicTacToe extends BitboardGame { public PlayResult play(long move) { // Player loses if move is invalid if ((move & getLegalMoves()) == 0 || Long.bitCount(move) != 1){ - return new PlayResult(GameState.WIN, getNextPlayer()); + state = new PlayResult(GameState.WIN, getNextPlayer()); + return state; } // Move is legal, make move @@ -50,7 +51,8 @@ public class BitboardTicTacToe extends BitboardGame { // Check if current player won if (checkWin(playerBitboard)) { - return new PlayResult(GameState.WIN, getCurrentPlayerIndex()); + state = new PlayResult(GameState.WIN, getCurrentPlayerIndex()); + return state; } // Proceed to next turn @@ -59,11 +61,13 @@ public class BitboardTicTacToe extends BitboardGame { // Check for early draw if (getLegalMoves() == 0L || checkEarlyDraw()) { - return new PlayResult(GameState.DRAW, -1); + state = new PlayResult(GameState.DRAW, -1); + return state; } // Nothing weird happened, continue on as normal - return new PlayResult(GameState.NORMAL, -1); + state = new PlayResult(GameState.NORMAL, -1); + return state; } private boolean checkWin(long board) { diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java index 41cfd77..7a30caa 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java @@ -41,15 +41,19 @@ public class MCTSAI> extends AbstractAI { return expanded >= children.length; } - public Node bestUCTChild(float explorationFactor) { + float calculateUCT() { + float exploitation = visits <= 0? 0 : value / visits; + float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { int bestChildIndex = -1; float bestScore = Float.NEGATIVE_INFINITY; for (int i = 0; i < expanded; i++) { - float exploitation = children[i].visits <= 0? 0 : children[i].value / children[i].visits; - float exploration = explorationFactor * (float)(Math.sqrt(Math.log(visits) / (children[i].visits + 0.001f))); - - float score = exploitation + exploration; + final float score = calculateUCT(); if (score > bestScore) { bestChildIndex = i; @@ -109,14 +113,12 @@ public class MCTSAI> extends AbstractAI { } } - System.out.println("Visit count: " + root.visits); - return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves()); } private Node selection(Node node) { while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) { - node = node.bestUCTChild(1.41f); + node = node.bestUCTChild(); } return node; diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java new file mode 100644 index 0000000..bde88b2 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java @@ -0,0 +1,195 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.Random; + +public class MCTSAI2> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + + } + + return highestUCTChild; + } + } + + private final Random random; + private final int milliseconds; + + public MCTSAI2(int milliseconds) { + this.random = new Random(); + this.milliseconds = milliseconds; + } + + public MCTSAI2(MCTSAI2 other) { + this.random = other.random; + this.milliseconds = other.milliseconds; + } + + @Override + public MCTSAI2 deepCopy() { + return new MCTSAI2<>(this); + } + + @Override + public long getMove(T game) { + final Node root = new Node(game, null, 0L); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + final Node mostVisitedChild = mostVisitedChild(root); + + return mostVisitedChild != null? mostVisitedChild.move : 0L; + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private Node selection(Node root) { + while (root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } else if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + value = -value; + leaf = leaf.parent; + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java new file mode 100644 index 0000000..efef955 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java @@ -0,0 +1,258 @@ +package org.toop.game.players.ai; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.framework.gameFramework.model.player.AbstractAI; + +import java.util.Random; + +public class MCTSAI3> extends AbstractAI { + private static class Node { + public TurnBasedGame state; + + public long move; + public long unexpandedMoves; + + public Node parent; + + public Node[] children; + public int expanded; + + public float value; + public int visits; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); + + this.state = state; + + this.move = move; + this.unexpandedMoves = legalMoves; + + this.parent = parent; + + this.children = new Node[Long.bitCount(legalMoves)]; + this.expanded = 0; + + this.value = 0.0f; + this.visits = 0; + } + + public Node(TurnBasedGame state) { + this(state, null, 0L); + } + + public boolean isFullyExpanded() { + return expanded == children.length; + } + + public float calculateUCT(int parentVisits) { + final float exploitation = value / visits; + final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + + return exploitation + exploration; + } + + public Node bestUCTChild() { + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; + + for (int i = 0; i < expanded; i++) { + final float childUCT = children[i].calculateUCT(visits); + + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; + } + + } + + return highestUCTChild; + } + } + + private final Random random; + + private Node root; + private final int milliseconds; + + public MCTSAI3(int milliseconds) { + this.random = new Random(); + + this.root = null; + this.milliseconds = milliseconds; + } + + public MCTSAI3(MCTSAI3 other) { + this.random = other.random; + + this.root = other.root; + this.milliseconds = other.milliseconds; + } + + @Override + public MCTSAI3 deepCopy() { + return new MCTSAI3<>(this); + } + + @Override + public long getMove(T game) { + detectRoot(game); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild != null? mostVisitedChild.move : 0L; + + newRoot(move); + + return move; + } + + private Node mostVisitedChild(Node root) { + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < root.expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + private void detectRoot(T game) { + if (root == null) { + root = new Node(game.deepCopy()); + return; + } + + final long[] currentBoards = game.getBoard(); + final long[] rootBoards = root.state.getBoard(); + + boolean detected = true; + + for (int i = 0; i < rootBoards.length; i++) { + if (rootBoards[i] != currentBoards[i]) { + detected = false; + break; + } + } + + if (detected) { + return; + } + + for (int i = 0; i < root.expanded; i++) { + final Node child = root.children[i]; + + final long[] childBoards = child.state.getBoard(); + + detected = true; + + for (int j = 0; j < childBoards.length; j++) { + if (childBoards[j] != currentBoards[j]) { + detected = false; + break; + } + } + + if (detected) { + root = child; + return; + } + } + + root = new Node(game.deepCopy()); + } + + private void newRoot(long move) { + for (final Node child : root.children) { + if (child.move == move) { + root = child; + break; + } + } + } + + private Node selection(Node root) { + while (root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); + } + + return root; + } + + private Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.expanded] = expandedChild; + leaf.expanded++; + + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; + } + + private float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); + } + + if (copiedState.getWinner() == playerIndex) { + return 1.0f; + } else if (copiedState.getWinner() >= 0) { + return -1.0f; + } + + return 0.0f; + } + + private void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + value = -value; + leaf = leaf.parent; + } + } + + private long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } +}