From a6b2356a5edfa5155794c7e47dad3b341fe4e7e0 Mon Sep 17 00:00:00 2001 From: Bas Antonius de Jong <49651652+BAFGdeJong@users.noreply.github.com> Date: Sat, 17 Jan 2026 04:05:11 +0100 Subject: [PATCH] update mcts, incremental merge (#311) * mcts v1, v2, v3, v4 done. v5 wip * update mcts * mcts v1, v2, v3, v4 done. v5 wip * update mcts * Merge changes on dev * update mcts --------- Co-authored-by: ramollia <> --- app/src/main/java/org/toop/Main.java | 85 +++-- app/src/main/java/org/toop/app/Server.java | 5 +- .../app/widget/view/LocalMultiplayerView.java | 14 +- .../org/toop/framework/game/BitboardGame.java | 1 + .../game/games/reversi/BitboardReversi.java | 45 ++- .../games/tictactoe/BitboardTicTacToe.java | 10 + .../game/players/ArtificialPlayer.java | 4 + .../model/game/TurnBasedGame.java | 3 + .../java/org/toop/game/players/ai/MCTSAI.java | 321 ++++++++++++------ .../org/toop/game/players/ai/MCTSAI2.java | 195 ----------- .../org/toop/game/players/ai/MCTSAI3.java | 258 -------------- .../toop/game/players/ai/mcts/MCTSAI1.java | 39 +++ .../toop/game/players/ai/mcts/MCTSAI2.java | 49 +++ .../toop/game/players/ai/mcts/MCTSAI3.java | 92 +++++ .../toop/game/players/ai/mcts/MCTSAI4.java | 107 ++++++ 15 files changed, 616 insertions(+), 612 deletions(-) delete mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI2.java delete mode 100644 game/src/main/java/org/toop/game/players/ai/MCTSAI3.java create mode 100644 game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java create mode 100644 game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java create mode 100644 game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java create mode 100644 game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index bd330d2..8b93ac6 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -1,53 +1,68 @@ package org.toop; import org.toop.app.App; +import org.toop.framework.game.games.reversi.BitboardReversi; +import org.toop.framework.game.players.ArtificialPlayer; +import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.RandomAI; +import org.toop.game.players.ai.mcts.MCTSAI1; +import org.toop.game.players.ai.mcts.MCTSAI2; +import org.toop.game.players.ai.mcts.MCTSAI3; +import org.toop.game.players.ai.mcts.MCTSAI4; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; public final class Main { static void main(String[] args) { App.run(args); - // testMCTS(10); + + // final ExecutorService executor = Executors.newFixedThreadPool(1); + // executor.execute(() -> testAIs(25)); } - // Voor onderzoek - // private static void testMCTS(int games) { - // var random = new ArtificialPlayer<>(new RandomAI(), "Random AI"); - // var v1 = new ArtificialPlayer<>(new MCTSAI(10), "MCTS V1 AI"); - // var v2 = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); - // var v2_2 = new ArtificialPlayer<>(new MCTSAI2(100), "MCTS V2_2 AI"); - // var v3 = new ArtificialPlayer<>(new MCTSAI3(10), "MCTS V3 AI"); + private static void testAIs(int games) { + var versions = new ArtificialPlayer[5]; + versions[0] = new ArtificialPlayer(new RandomAI(), "Random AI"); + versions[1] = new ArtificialPlayer(new MCTSAI1(1000), "MCTS V1 AI"); + versions[2] = new ArtificialPlayer(new MCTSAI2(1000), "MCTS V2 AI"); + versions[3] = new ArtificialPlayer(new MCTSAI3(10, 10), "MCTS V3 AI"); + versions[4] = new ArtificialPlayer(new MCTSAI4(10, 10), "MCTS V4 AI"); - // testAI(games, new Player[]{ v1, v2 }); - // // testAI(games, new Player[]{ v1, v3 }); + for (int i = 0; i < versions.length; i++) { + for (int j = i + 1; j < versions.length; j++) { + final int playerIndex1 = i % versions.length; + final int playerIndex2 = j % versions.length; - // // testAI(games, new Player[]{ random, v3 }); - // // testAI(games, new Player[]{ v2, v3 }); - // testAI(games, new Player[]{ v2, v3 }); - // // testAI(games, new Player[]{ v3, v2 }); - // } + testAIVSAI(games, new ArtificialPlayer[] { versions[playerIndex1], versions[playerIndex2]}); + } + } + } - // private static void testAI(int games, Player[] ais) { - // int wins = 0; - // int ties = 0; + private static void testAIVSAI(int games, ArtificialPlayer[] ais) { + int wins = 0; + int ties = 0; - // for (int i = 0; i < games; i++) { - // final BitboardReversi match = new BitboardReversi(ais); + for (int i = 0; i < games; i++) { + final BitboardReversi match = new BitboardReversi(); + match.init(ais); - // while (!match.isTerminal()) { - // final int currentAI = match.getCurrentTurn(); - // final long move = ais[currentAI].getMove(match); + while (!match.isTerminal()) { + final int currentAI = match.getCurrentTurn(); + final long move = ais[currentAI].getMove(match); - // match.play(move); - // } + match.play(move); + } - // if (match.getWinner() < 0) { - // ties++; - // continue; - // } + if (match.getWinner() < 0) { + ties++; + continue; + } - // wins += match.getWinner() == 0? 1 : 0; - // } + wins += match.getWinner() == 0? 1 : 0; + } - // System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName()); - // System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games); - // } -} + System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName()); + System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games); + } +} \ No newline at end of file diff --git a/app/src/main/java/org/toop/app/Server.java b/app/src/main/java/org/toop/app/Server.java index 43129e1..06a5b75 100644 --- a/app/src/main/java/org/toop/app/Server.java +++ b/app/src/main/java/org/toop/app/Server.java @@ -20,8 +20,7 @@ import org.toop.framework.networking.connection.clients.TournamentNetworkingClie import org.toop.framework.networking.connection.events.NetworkEvents; import org.toop.framework.networking.connection.types.NetworkingConnector; import org.toop.framework.networking.server.gateway.NettyGatewayServer; -import org.toop.framework.game.players.LocalPlayer; -import org.toop.game.players.ai.MCTSAI3; +import org.toop.game.players.ai.mcts.MCTSAI3; import org.toop.local.AppContext; import java.util.Arrays; @@ -211,7 +210,7 @@ public final class Server { Player[] players = new Player[2]; - players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000), user); + players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000, Runtime.getRuntime().availableProcessors()), user); players[opponentStartingTurn] = new OnlinePlayer(response.opponent()); switch (type) { diff --git a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java index afee3f2..e19732a 100644 --- a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java +++ b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java @@ -4,6 +4,7 @@ import javafx.application.Platform; import org.toop.app.GameInformation; import org.toop.app.gameControllers.ReversiBitController; import org.toop.app.gameControllers.TicTacToeBitController; +import org.toop.framework.game.players.LocalPlayer; import org.toop.framework.gameFramework.controller.GameController; import org.toop.framework.gameFramework.model.player.Player; import org.toop.framework.game.players.ArtificialPlayer; @@ -12,11 +13,10 @@ import org.toop.app.widget.complex.PlayerInfoWidget; import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.tutorial.*; -import org.toop.framework.game.players.LocalPlayer; -import org.toop.game.players.ai.MCTSAI; -import org.toop.game.players.ai.MCTSAI2; -import org.toop.game.players.ai.MCTSAI3; import org.toop.game.players.ai.MiniMaxAI; +import org.toop.game.players.ai.mcts.MCTSAI1; +import org.toop.game.players.ai.mcts.MCTSAI3; +import org.toop.game.players.ai.mcts.MCTSAI4; import org.toop.local.AppContext; import javafx.geometry.Pos; @@ -54,7 +54,7 @@ public class LocalMultiplayerView extends ViewWidget { if (information.players[0].isHuman) { players[0] = new LocalPlayer(information.players[0].name); } else { - players[0] = new ArtificialPlayer(new MCTSAI(100), "MCTS AI"); + players[0] = new ArtificialPlayer(new MCTSAI1(100), "MCTS AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer(information.players[1].name); @@ -83,12 +83,12 @@ public class LocalMultiplayerView extends ViewWidget { players[0] = new LocalPlayer(information.players[0].name); } else { // players[0] = new ArtificialPlayer(new RandomAI(), "Random AI"); - players[0] = new ArtificialPlayer(new MCTSAI3(50), "MCTS V3 AI"); + players[0] = new ArtificialPlayer(new MCTSAI4(500, 4), "MCTS V4 AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer(information.players[1].name); } else { - players[1] = new ArtificialPlayer(new MCTSAI(50), "MCTS V1 AI"); + players[1] = new ArtificialPlayer(new MCTSAI1(500), "MCTS V1 AI"); } if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { new ShowEnableTutorialWidget( diff --git a/framework/src/main/java/org/toop/framework/game/BitboardGame.java b/framework/src/main/java/org/toop/framework/game/BitboardGame.java index 76997c5..26d7c1e 100644 --- a/framework/src/main/java/org/toop/framework/game/BitboardGame.java +++ b/framework/src/main/java/org/toop/framework/game/BitboardGame.java @@ -6,6 +6,7 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.Player; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; // There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy public abstract class BitboardGame implements TurnBasedGame { diff --git a/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java b/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java index 102b94d..c093dd7 100644 --- a/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java +++ b/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java @@ -321,8 +321,51 @@ public class BitboardReversi extends BitboardGame { else if (blackCount > whiteCount){ return 0; } - else{ + else { return 1; } } + + @Override + public float rateMove(long move) { + final long corners = 0x8100000000000081L; + + if ((move & corners) != 0L) { + return 0.4f; + } + + final long xSquares = 0x0042000000004200L; + + if ((move & xSquares) != 0) { + return -0.4f; + } + + final long cSquares = 0x4281000000008142L; + + if ((move & cSquares) != 0) { + return -0.1f; + } + + return 0.0f; + } + + @Override + public long heuristicMove(long legalMoves) { + long bestMove = 0L; + float bestMoveRate = Float.NEGATIVE_INFINITY; + + while (legalMoves != 0L) { + final long move = legalMoves & -legalMoves; + final float moveRate = rateMove(move); + + if (moveRate > bestMoveRate) { + bestMove = move; + bestMoveRate = moveRate; + } + + legalMoves &= ~move; + } + + return bestMove; + } } \ No newline at end of file diff --git a/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java b/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java index fa89efc..cf0f35a 100644 --- a/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java +++ b/framework/src/main/java/org/toop/framework/game/games/tictactoe/BitboardTicTacToe.java @@ -110,4 +110,14 @@ public class BitboardTicTacToe extends BitboardGame { public BitboardTicTacToe deepCopy() { return new BitboardTicTacToe(this); } + + @Override + public float rateMove(long move) { + return 0.0f; + } + + @Override + public long heuristicMove(long legalMoves) { + return legalMoves; + } } \ No newline at end of file diff --git a/framework/src/main/java/org/toop/framework/game/players/ArtificialPlayer.java b/framework/src/main/java/org/toop/framework/game/players/ArtificialPlayer.java index 23bb2bc..43167c9 100644 --- a/framework/src/main/java/org/toop/framework/game/players/ArtificialPlayer.java +++ b/framework/src/main/java/org/toop/framework/game/players/ArtificialPlayer.java @@ -57,4 +57,8 @@ public class ArtificialPlayer extends AbstractPlayer { public ArtificialPlayer deepCopy() { return new ArtificialPlayer(this); } + + public AI getAi() { + return ai; + } } diff --git a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java index 6f5f8a7..6cd5996 100644 --- a/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java +++ b/framework/src/main/java/org/toop/framework/gameFramework/model/game/TurnBasedGame.java @@ -13,4 +13,7 @@ public interface TurnBasedGame extends DeepCopyable { PlayResult play(long move); PlayResult getState(); boolean isTerminal(); + + float rateMove(long move); + long heuristicMove(long legalMoves); } diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java index 8687846..495587d 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java @@ -1,193 +1,288 @@ package org.toop.game.players.ai; -import org.toop.framework.gameFramework.GameState; -import org.toop.framework.gameFramework.model.game.PlayResult; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.AbstractAI; import java.util.Random; -public class MCTSAI extends AbstractAI { - private static class Node { +public abstract class MCTSAI extends AbstractAI { + protected static class Node { public TurnBasedGame state; + public long move; + public long unexpandedMoves; public Node parent; - - public int expanded; public Node[] children; - public int visits; public float value; + public int visits; + + public float heuristic; + + public float solved; + + public Node(TurnBasedGame state, Node parent, long move) { + final long legalMoves = state.getLegalMoves(); - public Node(TurnBasedGame state, long move, Node parent) { this.state = state; + this.move = move; + this.unexpandedMoves = legalMoves; this.parent = parent; + this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - this.children = new Node[Long.bitCount(state.getLegalMoves())]; - - this.visits = 0; this.value = 0.0f; + this.visits = 0; + + this.heuristic = state.rateMove(move); + + this.solved = Float.NaN; } public Node(TurnBasedGame state) { - this(state, 0L, null); + this(state, null, 0L); + } + + public int getExpanded() { + return children.length - Long.bitCount(unexpandedMoves); } public boolean isFullyExpanded() { - return expanded >= children.length; + return unexpandedMoves == 0L; } - float calculateUCT() { - float exploitation = visits <= 0? 0 : value / visits; - float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits)); + public float calculateUCT(int parentVisits) { + if (visits == 0) { + return Float.POSITIVE_INFINITY; + } - return exploitation + exploration; + final float exploitation = value / visits; + final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits)); + final float bias = heuristic * 10.0f / (visits + 1); + + return exploitation + exploration + bias; } public Node bestUCTChild() { - int bestChildIndex = -1; - float bestScore = Float.NEGATIVE_INFINITY; + final int expanded = getExpanded(); + + Node highestUCTChild = null; + float highestUCT = Float.NEGATIVE_INFINITY; for (int i = 0; i < expanded; i++) { - final float score = calculateUCT(); + final float childUCT = children[i].calculateUCT(visits); - if (score > bestScore) { - bestChildIndex = i; - bestScore = score; + if (childUCT > highestUCT) { + highestUCTChild = children[i]; + highestUCT = childUCT; } } - return bestChildIndex >= 0? children[bestChildIndex] : this; + return highestUCTChild; } } - private final int milliseconds; + protected static final ThreadLocal random = ThreadLocal.withInitial(Random::new); + + protected final int milliseconds; + + protected int lastIterations; public MCTSAI(int milliseconds) { this.milliseconds = milliseconds; + + this.lastIterations = 0; } public MCTSAI(MCTSAI other) { this.milliseconds = other.milliseconds; + + this.lastIterations = other.lastIterations; } - @Override - public MCTSAI deepCopy() { - return new MCTSAI(this); + public int getLastIterations() { + return lastIterations; } - @Override - public long getMove(TurnBasedGame game) { - Node root = new Node(game.deepCopy()); - - long endTime = System.currentTimeMillis() + milliseconds; - - while (System.currentTimeMillis() <= endTime) { - Node node = selection(root); - long legalMoves = node.state.getLegalMoves(); - - if (legalMoves != 0) { - node = expansion(node, legalMoves); - } - - float result = 0.0f; - - if (node.state.getLegalMoves() != 0) { - result = simulation(node.state, game.getCurrentTurn()); - } - - backPropagation(node, result); + protected Node selection(Node root) { + // while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) { + while (root.isFullyExpanded() && !root.state.isTerminal()) { + root = root.bestUCTChild(); } - int mostVisitedIndex = -1; - int mostVisits = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisits) { - mostVisitedIndex = i; - mostVisits = root.children[i].visits; - } - } - - return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves()); + return root; } - private Node selection(Node node) { - while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) { - node = node.bestUCTChild(); + protected Node expansion(Node leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; } - return node; + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.getExpanded()] = expandedChild; + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; } - private Node expansion(Node node, long legalMoves) { - for (int i = 0; i < node.expanded; i++) { - legalMoves &= ~node.children[i].move; + protected float simulation(Node leaf) { + final TurnBasedGame copiedState = leaf.state.deepCopy(); + final int playerIndex = 1 - copiedState.getCurrentTurn(); + + while (!copiedState.isTerminal()) { + final long legalMoves = copiedState.getLegalMoves(); + final long randomMove = randomSetBit(legalMoves); + + copiedState.play(randomMove); } - if (legalMoves == 0L) { - return node; + if (copiedState.getWinner() == playerIndex) { + return 1.0f; } - long move = randomSetBit(legalMoves); - - TurnBasedGame copy = node.state.deepCopy(); - copy.play(move); - - Node newlyExpanded = new Node(copy, move, node); - - node.children[node.expanded] = newlyExpanded; - node.expanded++; - - return newlyExpanded; - } - - private float simulation(TurnBasedGame state, int playerIndex) { - TurnBasedGame copy = state.deepCopy(); - long legalMoves = copy.getLegalMoves(); - PlayResult result = null; - - while (legalMoves != 0) { - result = copy.play(randomSetBit(legalMoves)); - legalMoves = copy.getLegalMoves(); - } - - if (result.state() == GameState.WIN) { - if (result.player() == playerIndex) { - return 1.0f; - } - + if (copiedState.getWinner() >= 0) { return -1.0f; } - return -0.2f; + return 0.0f; } - private void backPropagation(Node node, float value) { - while (node != null) { - node.visits++; - node.value += value; - node = node.parent; + protected void backPropagation(Node leaf, float value) { + while (leaf != null) { + leaf.value += value; + leaf.visits++; + + if (Float.isNaN(leaf.solved)) { + updateSolvedStatus(leaf); + } + + value = -value; + leaf = leaf.parent; } } - public static long randomSetBit(long value) { - Random random = new Random(); + protected Node mostVisitedChild(Node root) { + final int expanded = root.getExpanded(); - int count = Long.bitCount(value); - int target = random.nextInt(count); + Node mostVisitedChild = null; + int mostVisited = -1; - while (true) { - int bit = Long.numberOfTrailingZeros(value); - if (target == 0) { - return 1L << bit; + for (int i = 0; i < expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; } + } + + return mostVisitedChild; + } + + protected Node findOrResetRoot(Node root, TurnBasedGame game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + final int expanded = root.getExpanded(); + + for (int i = 0; i < expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + protected Node findChildByMove(Node root, long move) { + final int expanded = root.getExpanded(); + + for (int i = 0; i < expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + protected boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } + + protected long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.get().nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { value &= value - 1; - target--; + } + + return value & -value; + } + + private void updateSolvedStatus(Node node) { + if (node.state.isTerminal()) { + final int winner = node.state.getWinner(); + final int mover = 1 - node.state.getCurrentTurn(); + + node.solved = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + + return; + } + + if (node.isFullyExpanded()) { + boolean allChildrenSolved = true; + boolean foundWinningMove = false; + boolean foundDrawMove = false; + + for (final Node child : node.children) { + if (!Float.isNaN(child.solved)) { + if (child.solved == -1.0f) { + foundWinningMove = true; + break; + } + + if (child.solved == 0.0f) { + foundDrawMove = true; + } + } else { + allChildrenSolved = false; + } + } + + if (foundWinningMove) { + node.solved = 1.0f; + } else if (allChildrenSolved) { + node.solved = foundDrawMove? 0.0f : -1.0f; + } } } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java deleted file mode 100644 index 872c693..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java +++ /dev/null @@ -1,195 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.Random; - -public class MCTSAI2 extends AbstractAI { - private static class Node { - public TurnBasedGame state; - - public long move; - public long unexpandedMoves; - - public Node parent; - - public Node[] children; - public int expanded; - - public float value; - public int visits; - - public Node(TurnBasedGame state, Node parent, long move) { - final long legalMoves = state.getLegalMoves(); - - this.state = state; - - this.move = move; - this.unexpandedMoves = legalMoves; - - this.parent = parent; - - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - - this.value = 0.0f; - this.visits = 0; - } - - public Node(TurnBasedGame state) { - this(state, null, 0L); - } - - public boolean isFullyExpanded() { - return expanded == children.length; - } - - public float calculateUCT(int parentVisits) { - final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - Node highestUCTChild = null; - float highestUCT = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); - - if (childUCT > highestUCT) { - highestUCTChild = children[i]; - highestUCT = childUCT; - } - - } - - return highestUCTChild; - } - } - - private final Random random; - private final int milliseconds; - - public MCTSAI2(int milliseconds) { - this.random = new Random(); - this.milliseconds = milliseconds; - } - - public MCTSAI2(MCTSAI2 other) { - this.random = other.random; - this.milliseconds = other.milliseconds; - } - - @Override - public MCTSAI2 deepCopy() { - return new MCTSAI2(this); - } - - @Override - public long getMove(TurnBasedGame game) { - final Node root = new Node(game, null, 0L); - - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - while (System.nanoTime() < endTime) { - Node leaf = selection(root); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - final Node mostVisitedChild = mostVisitedChild(root); - - return mostVisitedChild != null? mostVisitedChild.move : 0L; - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private Node selection(Node root) { - while (root.isFullyExpanded() && !root.state.isTerminal()) { - root = root.bestUCTChild(); - } - - return root; - } - - private Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; - } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; - } - - private float simulation(Node leaf) { - final TurnBasedGame copiedState = leaf.state.deepCopy(); - final int playerIndex = 1 - copiedState.getCurrentTurn(); - - while (!copiedState.isTerminal()) { - final long legalMoves = copiedState.getLegalMoves(); - final long randomMove = randomSetBit(legalMoves); - - copiedState.play(randomMove); - } - - if (copiedState.getWinner() == playerIndex) { - return 1.0f; - } else if (copiedState.getWinner() >= 0) { - return -1.0f; - } - - return 0.0f; - } - - private void backPropagation(Node leaf, float value) { - while (leaf != null) { - leaf.value += value; - leaf.visits++; - - value = -value; - leaf = leaf.parent; - } - } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java deleted file mode 100644 index fe892d1..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java +++ /dev/null @@ -1,258 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.Random; - -public class MCTSAI3 extends AbstractAI { - private static class Node { - public TurnBasedGame state; - - public long move; - public long unexpandedMoves; - - public Node parent; - - public Node[] children; - public int expanded; - - public float value; - public int visits; - - public Node(TurnBasedGame state, Node parent, long move) { - final long legalMoves = state.getLegalMoves(); - - this.state = state; - - this.move = move; - this.unexpandedMoves = legalMoves; - - this.parent = parent; - - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - - this.value = 0.0f; - this.visits = 0; - } - - public Node(TurnBasedGame state) { - this(state, null, 0L); - } - - public boolean isFullyExpanded() { - return expanded == children.length; - } - - public float calculateUCT(int parentVisits) { - final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - Node highestUCTChild = null; - float highestUCT = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); - - if (childUCT > highestUCT) { - highestUCTChild = children[i]; - highestUCT = childUCT; - } - - } - - return highestUCTChild; - } - } - - private final Random random; - - private Node root; - private final int milliseconds; - - public MCTSAI3(int milliseconds) { - this.random = new Random(); - - this.root = null; - this.milliseconds = milliseconds; - } - - public MCTSAI3(MCTSAI3 other) { - this.random = other.random; - - this.root = other.root; - this.milliseconds = other.milliseconds; - } - - @Override - public MCTSAI3 deepCopy() { - return new MCTSAI3(this); - } - - @Override - public long getMove(TurnBasedGame game) { - detectRoot(game); - - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - while (System.nanoTime() < endTime) { - Node leaf = selection(root); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - final Node mostVisitedChild = mostVisitedChild(root); - final long move = mostVisitedChild != null? mostVisitedChild.move : 0L; - - newRoot(move); - - return move; - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private void detectRoot(TurnBasedGame game) { - if (root == null) { - root = new Node(game.deepCopy()); - return; - } - - final long[] currentBoards = game.getBoard(); - final long[] rootBoards = root.state.getBoard(); - - boolean detected = true; - - for (int i = 0; i < rootBoards.length; i++) { - if (rootBoards[i] != currentBoards[i]) { - detected = false; - break; - } - } - - if (detected) { - return; - } - - for (int i = 0; i < root.expanded; i++) { - final Node child = root.children[i]; - - final long[] childBoards = child.state.getBoard(); - - detected = true; - - for (int j = 0; j < childBoards.length; j++) { - if (childBoards[j] != currentBoards[j]) { - detected = false; - break; - } - } - - if (detected) { - root = child; - return; - } - } - - root = new Node(game.deepCopy()); - } - - private void newRoot(long move) { - for (final Node child : root.children) { - if (child.move == move) { - root = child; - break; - } - } - } - - private Node selection(Node root) { - while (root.isFullyExpanded() && !root.state.isTerminal()) { - root = root.bestUCTChild(); - } - - return root; - } - - private Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; - } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; - } - - private float simulation(Node leaf) { - final TurnBasedGame copiedState = leaf.state.deepCopy(); - final int playerIndex = 1 - copiedState.getCurrentTurn(); - - while (!copiedState.isTerminal()) { - final long legalMoves = copiedState.getLegalMoves(); - final long randomMove = randomSetBit(legalMoves); - - copiedState.play(randomMove); - } - - if (copiedState.getWinner() == playerIndex) { - return 1.0f; - } else if (copiedState.getWinner() >= 0) { - return -1.0f; - } - - return 0.0f; - } - - private void backPropagation(Node leaf, float value) { - while (leaf != null) { - leaf.value += value; - leaf.visits++; - - value = -value; - leaf = leaf.parent; - } - } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } -} diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java new file mode 100644 index 0000000..6733e43 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java @@ -0,0 +1,39 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +public class MCTSAI1 extends MCTSAI { + public MCTSAI1(int milliseconds) { + super(milliseconds); + } + + public MCTSAI1(MCTSAI1 other) { + super(other); + } + + @Override + public MCTSAI1 deepCopy() { + return new MCTSAI1(this); + } + + @Override + public long getMove(TurnBasedGame game) { + final Node root = new Node(game, null, 0L); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + // while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java new file mode 100644 index 0000000..c7247a2 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java @@ -0,0 +1,49 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +public class MCTSAI2 extends MCTSAI { + private Node root; + + public MCTSAI2(int milliseconds) { + super(milliseconds); + + this.root = null; + } + + public MCTSAI2(MCTSAI2 other) { + super(other); + + this.root = other.root; + } + + @Override + public MCTSAI2 deepCopy() { + return new MCTSAI2(this); + } + + @Override + public long getMove(TurnBasedGame game) { + root = findOrResetRoot(root, game); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + // while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + root = findChildByMove(root, move); + + return move; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java new file mode 100644 index 0000000..ee1e202 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java @@ -0,0 +1,92 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI3 extends MCTSAI { + private final int threads; + + public MCTSAI3(int milliseconds, int threads) { + super(milliseconds); + + this.threads = threads; + } + + public MCTSAI3(MCTSAI3 other) { + super(other); + + this.threads = other.threads; + } + + @Override + public MCTSAI3 deepCopy() { + return new MCTSAI3(this); + } + + @Override + public long getMove(TurnBasedGame game) { + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + tasks.add(() -> { + final Node localRoot = new Node(game.deepCopy()); + + // while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } catch (Exception _) { + lastIterations = 0; + + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java new file mode 100644 index 0000000..c89839b --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java @@ -0,0 +1,107 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI4 extends MCTSAI { + private final int threads; + private final Node[] threadRoots; + + public MCTSAI4(int milliseconds, int threads) { + super(milliseconds); + + this.threads = threads; + this.threadRoots = new Node[threads]; + } + + public MCTSAI4(MCTSAI4 other) { + super(other); + + this.threads = other.threads; + this.threadRoots = other.threadRoots; + } + + @Override + public MCTSAI4 deepCopy() { + return new MCTSAI4(this); + } + + @Override + public long getMove(TurnBasedGame game) { + for (int i = 0; i < threads; i++) { + threadRoots[i] = findOrResetRoot(threadRoots[i], game); + } + + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + final int threadIndex = i; + + tasks.add(() -> { + final Node localRoot = threadRoots[threadIndex]; + + // while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + for (int i = 0; i < threads; i++) { + threadRoots[i] = findChildByMove(threadRoots[i], move); + } + + return move; + } catch (Exception _) { + lastIterations = 0; + + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } +} \ No newline at end of file