diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index a698ba6..d3ef1db 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -3,39 +3,39 @@ package org.toop; import org.toop.app.App; import org.toop.framework.gameFramework.model.player.Player; import org.toop.game.games.reversi.BitboardReversi; -import org.toop.game.games.tictactoe.BitboardTicTacToe; import org.toop.game.players.ArtificialPlayer; -import org.toop.game.players.ai.MCTSAI1; -import org.toop.game.players.ai.MCTSAI2; -import org.toop.game.players.ai.MCTSAI3; -import org.toop.game.players.ai.MCTSAI4; -import org.toop.game.players.ai.MCTSAI5; +import org.toop.game.players.ai.MCTSAI; +import org.toop.game.players.ai.RandomAI; +import org.toop.game.players.ai.mcts.MCTSAI1; +import org.toop.game.players.ai.mcts.MCTSAI2; +import org.toop.game.players.ai.mcts.MCTSAI3; +import org.toop.game.players.ai.mcts.MCTSAI4; public final class Main { static void main(String[] args) { - App.run(args); - // testMCTS(100); + // App.run(args); + testMCTS(25); } private static void testMCTS(int games) { var versions = new ArtificialPlayer[5]; - versions[0] = new ArtificialPlayer<>(new MCTSAI1(10), "MCTS V1 AI"); - versions[1] = new ArtificialPlayer<>(new MCTSAI2(10), "MCTS V2 AI"); - versions[2] = new ArtificialPlayer<>(new MCTSAI3(10, 10), "MCTS V3 AI"); - versions[3] = new ArtificialPlayer<>(new MCTSAI4(10, 10), "MCTS V4 AI"); - versions[4] = new ArtificialPlayer<>(new MCTSAI5(10, 10), "MCTS V5 AI"); + versions[0] = new ArtificialPlayer<>(new RandomAI(), "Random AI"); + versions[1] = new ArtificialPlayer<>(new MCTSAI1(1000), "MCTS V1 AI"); + versions[2] = new ArtificialPlayer<>(new MCTSAI2(1000), "MCTS V2 AI"); + versions[3] = new ArtificialPlayer<>(new MCTSAI3(10, 10), "MCTS V3 AI"); + versions[4] = new ArtificialPlayer<>(new MCTSAI4(10, 10), "MCTS V4 AI"); - for (int i = 2; i < versions.length; i++) { + for (int i = 0; i < versions.length; i++) { for (int j = i + 1; j < versions.length; j++) { final int playerIndex1 = i % versions.length; final int playerIndex2 = j % versions.length; - testAI(games, new Player[] { versions[playerIndex1], versions[playerIndex2]}); + testAI(games, new ArtificialPlayer[] { versions[playerIndex1], versions[playerIndex2]}); } } } - private static void testAI(int games, Player[] ais) { + private static void testAI(int games, ArtificialPlayer[] ais) { int wins = 0; int ties = 0; @@ -47,6 +47,11 @@ public final class Main { final long move = ais[currentAI].getMove(match); match.play(move); + + if (ais[currentAI].getAi() instanceof MCTSAI mcts) { + final int lastIterations = mcts.getLastIterations(); + System.out.printf("iterations %s: %d\n", ais[currentAI].getName(), lastIterations); + } } if (match.getWinner() < 0) { diff --git a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java index cfbf43f..bf65ab0 100644 --- a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java +++ b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java @@ -15,13 +15,11 @@ import org.toop.app.widget.complex.PlayerInfoWidget; import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.tutorial.*; -import org.toop.game.players.ai.MCTSAI1; -import org.toop.game.players.ai.MCTSAI2; -import org.toop.game.players.ai.MCTSAI3; -import org.toop.game.players.ai.MCTSAI4; -import org.toop.game.players.ai.MCTSAI5; +import org.toop.game.players.ai.mcts.MCTSAI1; +import org.toop.game.players.ai.mcts.MCTSAI2; +import org.toop.game.players.ai.mcts.MCTSAI3; +import org.toop.game.players.ai.mcts.MCTSAI4; import org.toop.game.players.ai.MiniMaxAI; -import org.toop.game.players.ai.RandomAI; import org.toop.local.AppContext; import javafx.geometry.Pos; @@ -87,12 +85,12 @@ public class LocalMultiplayerView extends ViewWidget { if (information.players[0].isHuman) { players[0] = new LocalPlayer<>(information.players[0].name); } else { - players[0] = new ArtificialPlayer<>(new MCTSAI4(100, 3), "MCTS V4 AI"); + players[0] = new ArtificialPlayer<>(new MCTSAI4(1000, 4), "MCTS V4 AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer<>(information.players[1].name); } else { - players[1] = new ArtificialPlayer<>(new MCTSAI5(100, 3), "MCTS V5 AI"); + players[1] = new ArtificialPlayer<>(new MCTSAI2(1000), "MCTS V2 AI"); } if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { new ShowEnableTutorialWidget( diff --git a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java index ed7283c..47a035f 100644 --- a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java +++ b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java @@ -6,8 +6,8 @@ import org.toop.framework.gameFramework.model.player.Player; import org.toop.game.BitboardGame; public class BitboardReversi extends BitboardGame { - private final long notAFile = 0xfefefefefefefefeL; - private final long notHFile = 0x7f7f7f7f7f7f7f7fL; + private static final long notAFile = 0xfefefefefefefefeL; + private static final long notHFile = 0x7f7f7f7f7f7f7f7fL; public BitboardReversi(Player[] players) { super(8, 8, 2, players); diff --git a/game/src/main/java/org/toop/game/players/ArtificialPlayer.java b/game/src/main/java/org/toop/game/players/ArtificialPlayer.java index c3df033..28600bd 100644 --- a/game/src/main/java/org/toop/game/players/ArtificialPlayer.java +++ b/game/src/main/java/org/toop/game/players/ArtificialPlayer.java @@ -52,4 +52,8 @@ public class ArtificialPlayer> extends AbstractPlayer public ArtificialPlayer deepCopy() { return new ArtificialPlayer<>(this); } + + public AI getAi() { + return ai; + } } diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java similarity index 64% rename from game/src/main/java/org/toop/game/players/ai/MCTSAI2.java rename to game/src/main/java/org/toop/game/players/ai/MCTSAI.java index 8c616a6..021b504 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI2.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java @@ -5,23 +5,22 @@ import org.toop.framework.gameFramework.model.player.AbstractAI; import java.util.Random; -public class MCTSAI2> extends AbstractAI { - private static class Node { +public abstract class MCTSAI> extends AbstractAI { + protected static class Node { public TurnBasedGame state; public long move; public long unexpandedMoves; public Node parent; - public Node[] children; - public int expanded; public float value; public int visits; - public boolean solved; - public float solvedValue; + public float heuristic; + + public float solved; public Node(TurnBasedGame state, Node parent, long move) { final long legalMoves = state.getLegalMoves(); @@ -32,23 +31,26 @@ public class MCTSAI2> extends AbstractAI { this.unexpandedMoves = legalMoves; this.parent = parent; - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; this.value = 0.0f; this.visits = 0; - this.solved = false; - this.solvedValue = 0.0f; + this.heuristic = state.rateMove(move); + + this.solved = Float.NaN; } public Node(TurnBasedGame state) { this(state, null, 0L); } + public int getExpanded() { + return children.length - Long.bitCount(unexpandedMoves); + } + public boolean isFullyExpanded() { - return expanded == children.length; + return unexpandedMoves == 0L; } public float calculateUCT(int parentVisits) { @@ -57,12 +59,15 @@ public class MCTSAI2> extends AbstractAI { } final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); + final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits)); + final float bias = heuristic * 10.0f / (visits + 1); - return exploitation + exploration; + return exploitation + exploration + bias; } public Node bestUCTChild() { + final int expanded = getExpanded(); + Node highestUCTChild = null; float highestUCT = Float.NEGATIVE_INFINITY; @@ -79,116 +84,34 @@ public class MCTSAI2> extends AbstractAI { } } - private static final Random random = new Random(); + protected static final ThreadLocal random = ThreadLocal.withInitial(Random::new); - private final int milliseconds; + protected final int milliseconds; - private Node root; + protected int lastIterations; - public MCTSAI2(int milliseconds) { + public MCTSAI(int milliseconds) { this.milliseconds = milliseconds; - - this.root = null; } - public MCTSAI2(MCTSAI2 other) { + public MCTSAI(MCTSAI other) { this.milliseconds = other.milliseconds; - - this.root = other.root; } - @Override - public MCTSAI2 deepCopy() { - return new MCTSAI2<>(this); + public int getLastIterations() { + return lastIterations; } - @Override - public long getMove(T game) { - root = findOrResetRoot(root, game); - - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - while (System.nanoTime() < endTime) { - Node leaf = selection(root); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - final Node mostVisitedChild = mostVisitedChild(root); - final long move = mostVisitedChild.move; - - root = findChildByMove(root, move); - - return move; - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private Node findOrResetRoot(Node root, T game) { - if (root == null) { - return new Node(game.deepCopy()); - } - - if (areStatesEqual(root.state.getBoard(), game.getBoard())) { - return root; - } - - for (int i = 0; i < root.expanded; i++) { - if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { - root.children[i].parent = null; - return root.children[i]; - } - } - - return new Node(game.deepCopy()); - } - - private Node findChildByMove(Node root, long move) { - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].move == move) { - root.children[i].parent = null; - return root.children[i]; - } - } - - return null; - } - - private boolean areStatesEqual(long[] state1, long[] state2) { - if (state1.length != state2.length) { - return false; - } - - for (int i = 0; i < state1.length; i++) { - if (state1[i] != state2[i]) { - return false; - } - } - - return true; - } - private Node selection(Node root) { - while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { + protected Node selection(Node root) { + // while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) { + while (root.isFullyExpanded() && !root.state.isTerminal()) { root = root.bestUCTChild(); } return root; } - private Node expansion(Node leaf) { + protected Node expansion(Node leaf) { if (leaf.unexpandedMoves == 0L) { return leaf; } @@ -200,15 +123,13 @@ public class MCTSAI2> extends AbstractAI { final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - + leaf.children[leaf.getExpanded()] = expandedChild; leaf.unexpandedMoves &= ~unexpandedMove; return expandedChild; } - private float simulation(Node leaf) { + protected float simulation(Node leaf) { final TurnBasedGame copiedState = leaf.state.deepCopy(); final int playerIndex = 1 - copiedState.getCurrentTurn(); @@ -230,12 +151,12 @@ public class MCTSAI2> extends AbstractAI { return 0.0f; } - private void backPropagation(Node leaf, float value) { + protected void backPropagation(Node leaf, float value) { while (leaf != null) { leaf.value += value; leaf.visits++; - if (!leaf.solved) { + if (Float.isNaN(leaf.solved)) { updateSolvedStatus(leaf); } @@ -244,14 +165,91 @@ public class MCTSAI2> extends AbstractAI { } } + protected Node mostVisitedChild(Node root) { + final int expanded = root.getExpanded(); + + Node mostVisitedChild = null; + int mostVisited = -1; + + for (int i = 0; i < expanded; i++) { + if (root.children[i].visits > mostVisited) { + mostVisitedChild = root.children[i]; + mostVisited = root.children[i].visits; + } + } + + return mostVisitedChild; + } + + protected Node findOrResetRoot(Node root, T game) { + if (root == null) { + return new Node(game.deepCopy()); + } + + if (areStatesEqual(root.state.getBoard(), game.getBoard())) { + return root; + } + + final int expanded = root.getExpanded(); + + for (int i = 0; i < expanded; i++) { + if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return new Node(game.deepCopy()); + } + + protected Node findChildByMove(Node root, long move) { + final int expanded = root.getExpanded(); + + for (int i = 0; i < expanded; i++) { + if (root.children[i].move == move) { + root.children[i].parent = null; + return root.children[i]; + } + } + + return null; + } + + protected boolean areStatesEqual(long[] state1, long[] state2) { + if (state1.length != state2.length) { + return false; + } + + for (int i = 0; i < state1.length; i++) { + if (state1[i] != state2[i]) { + return false; + } + } + + return true; + } + + protected long randomSetBit(long value) { + if (0L == value) { + return 0; + } + + final int bitCount = Long.bitCount(value); + final int randomBitCount = random.get().nextInt(bitCount); + + for (int i = 0; i < randomBitCount; i++) { + value &= value - 1; + } + + return value & -value; + } + private void updateSolvedStatus(Node node) { if (node.state.isTerminal()) { - node.solved = true; - final int winner = node.state.getWinner(); final int mover = 1 - node.state.getCurrentTurn(); - node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; + node.solved = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; return; } @@ -262,13 +260,13 @@ public class MCTSAI2> extends AbstractAI { boolean foundDrawMove = false; for (final Node child : node.children) { - if (child.solved) { - if (child.solvedValue == -1.0f) { + if (!Float.isNaN(child.solved)) { + if (child.solved == -1.0f) { foundWinningMove = true; break; } - if (child.solvedValue == 0.0f) { + if (child.solved == 0.0f) { foundDrawMove = true; } } else { @@ -277,27 +275,10 @@ public class MCTSAI2> extends AbstractAI { } if (foundWinningMove) { - node.solved = true; - node.solvedValue = 1.0f; + node.solved = 1.0f; } else if (allChildrenSolved) { - node.solved = true; - node.solvedValue = foundDrawMove? 0.0f : -1.0f; + node.solved = foundDrawMove? 0.0f : -1.0f; } } } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java deleted file mode 100644 index 6621db4..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI1.java +++ /dev/null @@ -1,250 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.Random; - -public class MCTSAI1> extends AbstractAI { - private static class Node { - public TurnBasedGame state; - - public long move; - public long unexpandedMoves; - - public Node parent; - - public Node[] children; - public int expanded; - - public float value; - public int visits; - - public boolean solved; - public float solvedValue; - - public Node(TurnBasedGame state, Node parent, long move) { - final long legalMoves = state.getLegalMoves(); - - this.state = state; - - this.move = move; - this.unexpandedMoves = legalMoves; - - this.parent = parent; - - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - - this.value = 0.0f; - this.visits = 0; - - this.solved = false; - this.solvedValue = 0.0f; - } - - public Node(TurnBasedGame state) { - this(state, null, 0L); - } - - public boolean isFullyExpanded() { - return expanded == children.length; - } - - public float calculateUCT(int parentVisits) { - if (visits == 0) { - return Float.POSITIVE_INFINITY; - } - - final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - Node highestUCTChild = null; - float highestUCT = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); - - if (childUCT > highestUCT) { - highestUCTChild = children[i]; - highestUCT = childUCT; - } - } - - return highestUCTChild; - } - } - - private static final Random random = new Random(); - - private final int milliseconds; - - public MCTSAI1(int milliseconds) { - this.milliseconds = milliseconds; - } - - public MCTSAI1(MCTSAI1 other) { - this.milliseconds = other.milliseconds; - } - - @Override - public MCTSAI1 deepCopy() { - return new MCTSAI1<>(this); - } - - @Override - public long getMove(T game) { - final Node root = new Node(game, null, 0L); - - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - while (System.nanoTime() < endTime) { - Node leaf = selection(root); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - final Node mostVisitedChild = mostVisitedChild(root); - return mostVisitedChild.move; - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private Node selection(Node root) { - while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { - root = root.bestUCTChild(); - } - - return root; - } - - private Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; - } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; - } - - private float simulation(Node leaf) { - final TurnBasedGame copiedState = leaf.state.deepCopy(); - final int playerIndex = 1 - copiedState.getCurrentTurn(); - - while (!copiedState.isTerminal()) { - final long legalMoves = copiedState.getLegalMoves(); - final long randomMove = randomSetBit(legalMoves); - - copiedState.play(randomMove); - } - - if (copiedState.getWinner() == playerIndex) { - return 1.0f; - } - - if (copiedState.getWinner() >= 0) { - return -1.0f; - } - - return 0.0f; - } - - private void backPropagation(Node leaf, float value) { - while (leaf != null) { - leaf.value += value; - leaf.visits++; - - if (!leaf.solved) { - updateSolvedStatus(leaf); - } - - value = -value; - leaf = leaf.parent; - } - } - - private void updateSolvedStatus(Node node) { - if (node.state.isTerminal()) { - node.solved = true; - - final int winner = node.state.getWinner(); - final int mover = 1 - node.state.getCurrentTurn(); - - node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; - - return; - } - - if (node.isFullyExpanded()) { - boolean allChildrenSolved = true; - boolean foundWinningMove = false; - boolean foundDrawMove = false; - - for (final Node child : node.children) { - if (child.solved) { - if (child.solvedValue == -1.0f) { - foundWinningMove = true; - break; - } - - if (child.solvedValue == 0.0f) { - foundDrawMove = true; - } - } else { - allChildrenSolved = false; - } - } - - if (foundWinningMove) { - node.solved = true; - node.solvedValue = 1.0f; - } else if (allChildrenSolved) { - node.solved = true; - node.solvedValue = foundDrawMove? 0.0f : -1.0f; - } - } - } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java deleted file mode 100644 index 2d85173..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI3.java +++ /dev/null @@ -1,297 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -public class MCTSAI3> extends AbstractAI { - private static class Node { - public TurnBasedGame state; - - public long move; - public long unexpandedMoves; - - public Node parent; - - public Node[] children; - public int expanded; - - public float value; - public int visits; - - public boolean solved; - public float solvedValue; - - public Node(TurnBasedGame state, Node parent, long move) { - final long legalMoves = state.getLegalMoves(); - - this.state = state; - - this.move = move; - this.unexpandedMoves = legalMoves; - - this.parent = parent; - - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - - this.value = 0.0f; - this.visits = 0; - - this.solved = false; - this.solvedValue = 0.0f; - } - - public Node(TurnBasedGame state) { - this(state, null, 0L); - } - - public boolean isFullyExpanded() { - return expanded == children.length; - } - - public float calculateUCT(int parentVisits) { - if (visits == 0) { - return Float.POSITIVE_INFINITY; - } - - final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - Node highestUCTChild = null; - float highestUCT = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); - - if (childUCT > highestUCT) { - highestUCTChild = children[i]; - highestUCT = childUCT; - } - } - - return highestUCTChild; - } - } - - private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); - - private final int milliseconds; - private final int threads; - - public MCTSAI3(int milliseconds, int threads) { - this.milliseconds = milliseconds; - this.threads = threads; - } - - public MCTSAI3(MCTSAI3 other) { - this.milliseconds = other.milliseconds; - this.threads = other.threads; - } - - @Override - public MCTSAI3 deepCopy() { - return new MCTSAI3<>(this); - } - - @Override - public long getMove(T game) { - final ExecutorService pool = Executors.newFixedThreadPool(threads); - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - final List> tasks = new ArrayList<>(); - - for (int i = 0; i < threads; i++) { - tasks.add(() -> { - final Node localRoot = new Node(game.deepCopy()); - - while (System.nanoTime() < endTime) { - Node leaf = selection(localRoot); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - return localRoot; - }); - } - - try { - final List> results = pool.invokeAll(tasks); - - pool.shutdown(); - - final Node root = new Node(game.deepCopy()); - - for (int i = 0; i < root.children.length; i++) { - expansion(root); - } - - for (final Future result : results) { - final Node localRoot = result.get(); - - for (final Node localChild : localRoot.children) { - for (int i = 0; i < root.children.length; i++) { - if (localChild.move == root.children[i].move) { - root.children[i].visits += localChild.visits; - root.visits += localChild.visits; - break; - } - } - } - } - - final Node mostVisitedChild = mostVisitedChild(root); - return mostVisitedChild.move; - } catch (Exception _) { - final long legalMoves = game.getLegalMoves(); - return randomSetBit(legalMoves); - } - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private Node selection(Node root) { - while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { - root = root.bestUCTChild(); - } - - return root; - } - - private Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; - } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; - } - - private float simulation(Node leaf) { - final TurnBasedGame copiedState = leaf.state.deepCopy(); - final int playerIndex = 1 - copiedState.getCurrentTurn(); - - while (!copiedState.isTerminal()) { - final long legalMoves = copiedState.getLegalMoves(); - final long randomMove = randomSetBit(legalMoves); - - copiedState.play(randomMove); - } - - if (copiedState.getWinner() == playerIndex) { - return 1.0f; - } - - if (copiedState.getWinner() >= 0) { - return -1.0f; - } - - return 0.0f; - } - - private void backPropagation(Node leaf, float value) { - while (leaf != null) { - leaf.value += value; - leaf.visits++; - - if (!leaf.solved) { - updateSolvedStatus(leaf); - } - - value = -value; - leaf = leaf.parent; - } - } - - private void updateSolvedStatus(Node node) { - if (node.state.isTerminal()) { - node.solved = true; - - final int winner = node.state.getWinner(); - final int mover = 1 - node.state.getCurrentTurn(); - - node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; - - return; - } - - if (node.isFullyExpanded()) { - boolean allChildrenSolved = true; - boolean foundWinningMove = false; - boolean foundDrawMove = false; - - for (final Node child : node.children) { - if (child.solved) { - if (child.solvedValue == -1.0f) { - foundWinningMove = true; - break; - } - - if (child.solvedValue == 0.0f) { - foundDrawMove = true; - } - } else { - allChildrenSolved = false; - } - } - - if (foundWinningMove) { - node.solved = true; - node.solvedValue = 1.0f; - } else if (allChildrenSolved) { - node.solved = true; - node.solvedValue = foundDrawMove? 0.0f : -1.0f; - } - } - } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.get().nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java deleted file mode 100644 index 262d9ae..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI4.java +++ /dev/null @@ -1,359 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -public class MCTSAI4> extends AbstractAI { - private static class Node { - public TurnBasedGame state; - - public long move; - public long unexpandedMoves; - - public Node parent; - - public Node[] children; - public int expanded; - - public float value; - public int visits; - - public boolean solved; - public float solvedValue; - - public Node(TurnBasedGame state, Node parent, long move) { - final long legalMoves = state.getLegalMoves(); - - this.state = state; - - this.move = move; - this.unexpandedMoves = legalMoves; - - this.parent = parent; - - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - - this.value = 0.0f; - this.visits = 0; - - this.solved = false; - this.solvedValue = 0.0f; - } - - public Node(TurnBasedGame state) { - this(state, null, 0L); - } - - public boolean isFullyExpanded() { - return expanded == children.length; - } - - public float calculateUCT(int parentVisits) { - if (visits == 0) { - return Float.POSITIVE_INFINITY; - } - - final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - - return exploitation + exploration; - } - - public Node bestUCTChild() { - Node highestUCTChild = null; - float highestUCT = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); - - if (childUCT > highestUCT) { - highestUCTChild = children[i]; - highestUCT = childUCT; - } - } - - return highestUCTChild; - } - } - - private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); - - private final int milliseconds; - private final int threads; - - private final Node[] threadRoots; - - public MCTSAI4(int milliseconds, int threads) { - this.milliseconds = milliseconds; - this.threads = threads; - - this.threadRoots = new Node[threads]; - } - - public MCTSAI4(MCTSAI4 other) { - this.milliseconds = other.milliseconds; - this.threads = other.threads; - - this.threadRoots = other.threadRoots; - } - - @Override - public MCTSAI4 deepCopy() { - return new MCTSAI4<>(this); - } - - @Override - public long getMove(T game) { - for (int i = 0; i < threads; i++) { - threadRoots[i] = findOrResetRoot(threadRoots[i], game); - } - - final ExecutorService pool = Executors.newFixedThreadPool(threads); - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - final List> tasks = new ArrayList<>(); - - for (int i = 0; i < threads; i++) { - final int threadIndex = i; - - tasks.add(() -> { - final Node localRoot = threadRoots[threadIndex]; - - while (System.nanoTime() < endTime) { - Node leaf = selection(localRoot); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - return localRoot; - }); - } - - try { - final List> results = pool.invokeAll(tasks); - - pool.shutdown(); - - final Node root = new Node(game.deepCopy()); - - for (int i = 0; i < root.children.length; i++) { - expansion(root); - } - - for (final Future result : results) { - final Node localRoot = result.get(); - - for (final Node localChild : localRoot.children) { - for (int i = 0; i < root.children.length; i++) { - if (localChild.move == root.children[i].move) { - root.children[i].visits += localChild.visits; - root.visits += localChild.visits; - break; - } - } - } - } - - final Node mostVisitedChild = mostVisitedChild(root); - final long move = mostVisitedChild.move; - - for (int i = 0; i < threads; i++) { - threadRoots[i] = findChildByMove(threadRoots[i], move); - } - - return move; - } catch (Exception _) { - final long legalMoves = game.getLegalMoves(); - return randomSetBit(legalMoves); - } - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private Node findOrResetRoot(Node root, T game) { - if (root == null) { - return new Node(game.deepCopy()); - } - - if (areStatesEqual(root.state.getBoard(), game.getBoard())) { - return root; - } - - for (int i = 0; i < root.expanded; i++) { - if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { - root.children[i].parent = null; - return root.children[i]; - } - } - - return new Node(game.deepCopy()); - } - - private Node findChildByMove(Node root, long move) { - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].move == move) { - root.children[i].parent = null; - return root.children[i]; - } - } - - return null; - } - - private boolean areStatesEqual(long[] state1, long[] state2) { - if (state1.length != state2.length) { - return false; - } - - for (int i = 0; i < state1.length; i++) { - if (state1[i] != state2[i]) { - return false; - } - } - - return true; - } - - private Node selection(Node root) { - while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { - root = root.bestUCTChild(); - } - - return root; - } - - private Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; - } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; - } - - private float simulation(Node leaf) { - final TurnBasedGame copiedState = leaf.state.deepCopy(); - final int playerIndex = 1 - copiedState.getCurrentTurn(); - - while (!copiedState.isTerminal()) { - final long legalMoves = copiedState.getLegalMoves(); - final long randomMove = randomSetBit(legalMoves); - - copiedState.play(randomMove); - } - - if (copiedState.getWinner() == playerIndex) { - return 1.0f; - } - - if (copiedState.getWinner() >= 0) { - return -1.0f; - } - - return 0.0f; - } - - private void backPropagation(Node leaf, float value) { - while (leaf != null) { - leaf.value += value; - leaf.visits++; - - if (!leaf.solved) { - updateSolvedStatus(leaf); - } - - value = -value; - leaf = leaf.parent; - } - } - - private void updateSolvedStatus(Node node) { - if (node.state.isTerminal()) { - node.solved = true; - - final int winner = node.state.getWinner(); - final int mover = 1 - node.state.getCurrentTurn(); - - node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; - - return; - } - - if (node.isFullyExpanded()) { - boolean allChildrenSolved = true; - boolean foundWinningMove = false; - boolean foundDrawMove = false; - - for (final Node child : node.children) { - if (child.solved) { - if (child.solvedValue == -1.0f) { - foundWinningMove = true; - break; - } - - if (child.solvedValue == 0.0f) { - foundDrawMove = true; - } - } else { - allChildrenSolved = false; - } - } - - if (foundWinningMove) { - node.solved = true; - node.solvedValue = 1.0f; - } else if (allChildrenSolved) { - node.solved = true; - node.solvedValue = foundDrawMove? 0.0f : -1.0f; - } - } - } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.get().nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java deleted file mode 100644 index 452eaf8..0000000 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI5.java +++ /dev/null @@ -1,371 +0,0 @@ -package org.toop.game.players.ai; - -import org.toop.framework.gameFramework.model.game.TurnBasedGame; -import org.toop.framework.gameFramework.model.player.AbstractAI; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -public class MCTSAI5> extends AbstractAI { - private static class Node { - public TurnBasedGame state; - - public long move; - public long unexpandedMoves; - - public Node parent; - - public Node[] children; - public int expanded; - - public float value; - public int visits; - - public boolean solved; - public float solvedValue; - - public float heuristic; - - public Node(TurnBasedGame state, Node parent, long move) { - final long legalMoves = state.getLegalMoves(); - - this.state = state; - - this.move = move; - this.unexpandedMoves = legalMoves; - - this.parent = parent; - - this.children = new Node[Long.bitCount(legalMoves)]; - this.expanded = 0; - - this.value = 0.0f; - this.visits = 0; - - this.solved = false; - this.solvedValue = 0.0f; - - this.heuristic = state.rateMove(move); - } - - public Node(TurnBasedGame state) { - this(state, null, 0L); - } - - public boolean isFullyExpanded() { - return expanded == children.length; - } - - public float calculateUCT(int parentVisits) { - if (visits == 0) { - return Float.POSITIVE_INFINITY; - } - - final float exploitation = value / visits; - final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - final float bias = heuristic / visits; - - return exploitation + exploration + bias; - } - - public Node bestUCTChild() { - Node highestUCTChild = null; - float highestUCT = Float.NEGATIVE_INFINITY; - - for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); - - if (childUCT > highestUCT) { - highestUCTChild = children[i]; - highestUCT = childUCT; - } - } - - return highestUCTChild; - } - } - - private static final ThreadLocal random = ThreadLocal.withInitial(Random::new); - - private final int milliseconds; - private final int threads; - - private final Node[] threadRoots; - - public MCTSAI5(int milliseconds, int threads) { - this.milliseconds = milliseconds; - this.threads = threads; - - this.threadRoots = new Node[threads]; - } - - public MCTSAI5(MCTSAI5 other) { - this.milliseconds = other.milliseconds; - this.threads = other.threads; - - this.threadRoots = other.threadRoots; - } - - @Override - public MCTSAI5 deepCopy() { - return new MCTSAI5<>(this); - } - - @Override - public long getMove(T game) { - for (int i = 0; i < threads; i++) { - threadRoots[i] = findOrResetRoot(threadRoots[i], game); - } - - final ExecutorService pool = Executors.newFixedThreadPool(threads); - final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - - final List> tasks = new ArrayList<>(); - - for (int i = 0; i < threads; i++) { - final int threadIndex = i; - - tasks.add(() -> { - final Node localRoot = threadRoots[threadIndex]; - - while (System.nanoTime() < endTime) { - Node leaf = selection(localRoot); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - return localRoot; - }); - } - - try { - final List> results = pool.invokeAll(tasks); - - pool.shutdown(); - - final Node root = new Node(game.deepCopy()); - - for (int i = 0; i < root.children.length; i++) { - expansion(root); - } - - for (final Future result : results) { - final Node localRoot = result.get(); - - for (final Node localChild : localRoot.children) { - for (int i = 0; i < root.children.length; i++) { - if (localChild.move == root.children[i].move) { - root.children[i].visits += localChild.visits; - root.visits += localChild.visits; - break; - } - } - } - } - - final Node mostVisitedChild = mostVisitedChild(root); - final long move = mostVisitedChild.move; - - for (int i = 0; i < threads; i++) { - threadRoots[i] = findChildByMove(threadRoots[i], move); - } - - return move; - } catch (Exception _) { - final long legalMoves = game.getLegalMoves(); - return randomSetBit(legalMoves); - } - } - - private Node mostVisitedChild(Node root) { - Node mostVisitedChild = null; - int mostVisited = -1; - - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].visits > mostVisited) { - mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; - } - } - - return mostVisitedChild; - } - - private Node findOrResetRoot(Node root, T game) { - if (root == null) { - return new Node(game.deepCopy()); - } - - if (areStatesEqual(root.state.getBoard(), game.getBoard())) { - return root; - } - - for (int i = 0; i < root.expanded; i++) { - if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) { - root.children[i].parent = null; - return root.children[i]; - } - } - - return new Node(game.deepCopy()); - } - - private Node findChildByMove(Node root, long move) { - for (int i = 0; i < root.expanded; i++) { - if (root.children[i].move == move) { - root.children[i].parent = null; - return root.children[i]; - } - } - - return null; - } - - private boolean areStatesEqual(long[] state1, long[] state2) { - if (state1.length != state2.length) { - return false; - } - - for (int i = 0; i < state1.length; i++) { - if (state1[i] != state2[i]) { - return false; - } - } - - return true; - } - - private Node selection(Node root) { - while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) { - root = root.bestUCTChild(); - } - - return root; - } - - private Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; - } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.expanded] = expandedChild; - leaf.expanded++; - - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; - } - - private float simulation(Node leaf) { - final TurnBasedGame copiedState = leaf.state.deepCopy(); - final int playerIndex = 1 - copiedState.getCurrentTurn(); - - while (!copiedState.isTerminal()) { - final long legalMoves = copiedState.getLegalMoves(); - - long move = 0L; - - if (random.get().nextFloat() > 0.9f) { - move = copiedState.heuristicMove(legalMoves); - } else { - move = randomSetBit(legalMoves); - } - - copiedState.play(move); - } - - if (copiedState.getWinner() == playerIndex) { - return 1.0f; - } - - if (copiedState.getWinner() >= 0) { - return -1.0f; - } - - return 0.0f; - } - - private void backPropagation(Node leaf, float value) { - while (leaf != null) { - leaf.value += value; - leaf.visits++; - - if (!leaf.solved) { - updateSolvedStatus(leaf); - } - - value = -value; - leaf = leaf.parent; - } - } - - private void updateSolvedStatus(Node node) { - if (node.state.isTerminal()) { - node.solved = true; - - final int winner = node.state.getWinner(); - final int mover = 1 - node.state.getCurrentTurn(); - - node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; - - return; - } - - if (node.isFullyExpanded()) { - boolean allChildrenSolved = true; - boolean foundWinningMove = false; - boolean foundDrawMove = false; - - for (final Node child : node.children) { - if (child.solved) { - if (child.solvedValue == -1.0f) { - foundWinningMove = true; - break; - } - - if (child.solvedValue == 0.0f) { - foundDrawMove = true; - } - } else { - allChildrenSolved = false; - } - } - - if (foundWinningMove) { - node.solved = true; - node.solvedValue = 1.0f; - } else if (allChildrenSolved) { - node.solved = true; - node.solvedValue = foundDrawMove? 0.0f : -1.0f; - } - } - } - - private long randomSetBit(long value) { - if (0L == value) { - return 0; - } - - final int bitCount = Long.bitCount(value); - final int randomBitCount = random.get().nextInt(bitCount); - - for (int i = 0; i < randomBitCount; i++) { - value &= value - 1; - } - - return value & -value; - } -} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java new file mode 100644 index 0000000..c52472a --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java @@ -0,0 +1,39 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +public class MCTSAI1> extends MCTSAI { + public MCTSAI1(int milliseconds) { + super(milliseconds); + } + + public MCTSAI1(MCTSAI1 other) { + super(other); + } + + @Override + public MCTSAI1 deepCopy() { + return new MCTSAI1<>(this); + } + + @Override + public long getMove(T game) { + final Node root = new Node(game, null, 0L); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + // while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java new file mode 100644 index 0000000..aabafb6 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java @@ -0,0 +1,49 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +public class MCTSAI2> extends MCTSAI { + private Node root; + + public MCTSAI2(int milliseconds) { + super(milliseconds); + + this.root = null; + } + + public MCTSAI2(MCTSAI2 other) { + super(other); + + this.root = other.root; + } + + @Override + public MCTSAI2 deepCopy() { + return new MCTSAI2<>(this); + } + + @Override + public long getMove(T game) { + root = findOrResetRoot(root, game); + + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + // while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + root = findChildByMove(root, move); + + return move; + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java new file mode 100644 index 0000000..ab86b11 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java @@ -0,0 +1,91 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI3> extends MCTSAI { + private final int threads; + + public MCTSAI3(int milliseconds, int threads) { + super(milliseconds); + + this.threads = threads; + } + + public MCTSAI3(MCTSAI3 other) { + super(other); + + this.threads = other.threads; + } + + @Override + public MCTSAI3 deepCopy() { + return new MCTSAI3<>(this); + } + + @Override + public long getMove(T game) { + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + tasks.add(() -> { + final Node localRoot = new Node(game.deepCopy()); + + while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + return mostVisitedChild.move; + } catch (Exception _) { + lastIterations = 0; + + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } +} \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java new file mode 100644 index 0000000..ff2fb59 --- /dev/null +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java @@ -0,0 +1,107 @@ +package org.toop.game.players.ai.mcts; + +import org.toop.framework.gameFramework.model.game.TurnBasedGame; +import org.toop.game.players.ai.MCTSAI; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +public class MCTSAI4> extends MCTSAI { + private final int threads; + private final Node[] threadRoots; + + public MCTSAI4(int milliseconds, int threads) { + super(milliseconds); + + this.threads = threads; + this.threadRoots = new Node[threads]; + } + + public MCTSAI4(MCTSAI4 other) { + super(other); + + this.threads = other.threads; + this.threadRoots = other.threadRoots; + } + + @Override + public MCTSAI4 deepCopy() { + return new MCTSAI4<>(this); + } + + @Override + public long getMove(T game) { + for (int i = 0; i < threads; i++) { + threadRoots[i] = findOrResetRoot(threadRoots[i], game); + } + + final ExecutorService pool = Executors.newFixedThreadPool(threads); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + + final List> tasks = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + final int threadIndex = i; + + tasks.add(() -> { + final Node localRoot = threadRoots[threadIndex]; + + // while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) { + while (System.nanoTime() < endTime) { + Node leaf = selection(localRoot); + leaf = expansion(leaf); + final float value = simulation(leaf); + backPropagation(leaf, value); + } + + return localRoot; + }); + } + + try { + final List> results = pool.invokeAll(tasks); + + pool.shutdown(); + + final Node root = new Node(game.deepCopy()); + + for (int i = 0; i < root.children.length; i++) { + expansion(root); + } + + for (final Future result : results) { + final Node localRoot = result.get(); + + for (final Node localChild : localRoot.children) { + for (int i = 0; i < root.children.length; i++) { + if (localChild.move == root.children[i].move) { + root.children[i].visits += localChild.visits; + root.visits += localChild.visits; + break; + } + } + } + } + + lastIterations = root.visits; + + final Node mostVisitedChild = mostVisitedChild(root); + final long move = mostVisitedChild.move; + + for (int i = 0; i < threads; i++) { + threadRoots[i] = findChildByMove(threadRoots[i], move); + } + + return move; + } catch (Exception _) { + lastIterations = 0; + + final long legalMoves = game.getLegalMoves(); + return randomSetBit(legalMoves); + } + } +} \ No newline at end of file