update mcts

This commit is contained in:
ramollia
2026-01-16 12:10:45 +01:00
parent a6f5f2c854
commit c54b2a19e2
13 changed files with 437 additions and 1440 deletions

View File

@@ -3,39 +3,39 @@ package org.toop;
import org.toop.app.App; import org.toop.app.App;
import org.toop.framework.gameFramework.model.player.Player; import org.toop.framework.gameFramework.model.player.Player;
import org.toop.game.games.reversi.BitboardReversi; import org.toop.game.games.reversi.BitboardReversi;
import org.toop.game.games.tictactoe.BitboardTicTacToe;
import org.toop.game.players.ArtificialPlayer; import org.toop.game.players.ArtificialPlayer;
import org.toop.game.players.ai.MCTSAI1; import org.toop.game.players.ai.MCTSAI;
import org.toop.game.players.ai.MCTSAI2; import org.toop.game.players.ai.RandomAI;
import org.toop.game.players.ai.MCTSAI3; import org.toop.game.players.ai.mcts.MCTSAI1;
import org.toop.game.players.ai.MCTSAI4; import org.toop.game.players.ai.mcts.MCTSAI2;
import org.toop.game.players.ai.MCTSAI5; import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.mcts.MCTSAI4;
public final class Main { public final class Main {
static void main(String[] args) { static void main(String[] args) {
App.run(args); // App.run(args);
// testMCTS(100); testMCTS(25);
} }
private static void testMCTS(int games) { private static void testMCTS(int games) {
var versions = new ArtificialPlayer[5]; var versions = new ArtificialPlayer[5];
versions[0] = new ArtificialPlayer<>(new MCTSAI1<BitboardTicTacToe>(10), "MCTS V1 AI"); versions[0] = new ArtificialPlayer<>(new RandomAI<BitboardReversi>(), "Random AI");
versions[1] = new ArtificialPlayer<>(new MCTSAI2<BitboardTicTacToe>(10), "MCTS V2 AI"); versions[1] = new ArtificialPlayer<>(new MCTSAI1<BitboardReversi>(1000), "MCTS V1 AI");
versions[2] = new ArtificialPlayer<>(new MCTSAI3<BitboardTicTacToe>(10, 10), "MCTS V3 AI"); versions[2] = new ArtificialPlayer<>(new MCTSAI2<BitboardReversi>(1000), "MCTS V2 AI");
versions[3] = new ArtificialPlayer<>(new MCTSAI4<BitboardTicTacToe>(10, 10), "MCTS V4 AI"); versions[3] = new ArtificialPlayer<>(new MCTSAI3<BitboardReversi>(10, 10), "MCTS V3 AI");
versions[4] = new ArtificialPlayer<>(new MCTSAI5<BitboardTicTacToe>(10, 10), "MCTS V5 AI"); versions[4] = new ArtificialPlayer<>(new MCTSAI4<BitboardReversi>(10, 10), "MCTS V4 AI");
for (int i = 2; i < versions.length; i++) { for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) { for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length; final int playerIndex1 = i % versions.length;
final int playerIndex2 = j % versions.length; final int playerIndex2 = j % versions.length;
testAI(games, new Player[] { versions[playerIndex1], versions[playerIndex2]}); testAI(games, new ArtificialPlayer[] { versions[playerIndex1], versions[playerIndex2]});
} }
} }
} }
private static void testAI(int games, Player<BitboardReversi>[] ais) { private static void testAI(int games, ArtificialPlayer<BitboardReversi>[] ais) {
int wins = 0; int wins = 0;
int ties = 0; int ties = 0;
@@ -47,6 +47,11 @@ public final class Main {
final long move = ais[currentAI].getMove(match); final long move = ais[currentAI].getMove(match);
match.play(move); match.play(move);
if (ais[currentAI].getAi() instanceof MCTSAI<?> mcts) {
final int lastIterations = mcts.getLastIterations();
System.out.printf("iterations %s: %d\n", ais[currentAI].getName(), lastIterations);
}
} }
if (match.getWinner() < 0) { if (match.getWinner() < 0) {

View File

@@ -15,13 +15,11 @@ import org.toop.app.widget.complex.PlayerInfoWidget;
import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.complex.ViewWidget;
import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.popup.ErrorPopup;
import org.toop.app.widget.tutorial.*; import org.toop.app.widget.tutorial.*;
import org.toop.game.players.ai.MCTSAI1; import org.toop.game.players.ai.mcts.MCTSAI1;
import org.toop.game.players.ai.MCTSAI2; import org.toop.game.players.ai.mcts.MCTSAI2;
import org.toop.game.players.ai.MCTSAI3; import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.MCTSAI4; import org.toop.game.players.ai.mcts.MCTSAI4;
import org.toop.game.players.ai.MCTSAI5;
import org.toop.game.players.ai.MiniMaxAI; import org.toop.game.players.ai.MiniMaxAI;
import org.toop.game.players.ai.RandomAI;
import org.toop.local.AppContext; import org.toop.local.AppContext;
import javafx.geometry.Pos; import javafx.geometry.Pos;
@@ -87,12 +85,12 @@ public class LocalMultiplayerView extends ViewWidget {
if (information.players[0].isHuman) { if (information.players[0].isHuman) {
players[0] = new LocalPlayer<>(information.players[0].name); players[0] = new LocalPlayer<>(information.players[0].name);
} else { } else {
players[0] = new ArtificialPlayer<>(new MCTSAI4<BitboardReversi>(100, 3), "MCTS V4 AI"); players[0] = new ArtificialPlayer<>(new MCTSAI4<BitboardReversi>(1000, 4), "MCTS V4 AI");
} }
if (information.players[1].isHuman) { if (information.players[1].isHuman) {
players[1] = new LocalPlayer<>(information.players[1].name); players[1] = new LocalPlayer<>(information.players[1].name);
} else { } else {
players[1] = new ArtificialPlayer<>(new MCTSAI5<BitboardReversi>(100, 3), "MCTS V5 AI"); players[1] = new ArtificialPlayer<>(new MCTSAI2<BitboardReversi>(1000), "MCTS V2 AI");
} }
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
new ShowEnableTutorialWidget( new ShowEnableTutorialWidget(

View File

@@ -6,8 +6,8 @@ import org.toop.framework.gameFramework.model.player.Player;
import org.toop.game.BitboardGame; import org.toop.game.BitboardGame;
public class BitboardReversi extends BitboardGame<BitboardReversi> { public class BitboardReversi extends BitboardGame<BitboardReversi> {
private final long notAFile = 0xfefefefefefefefeL; private static final long notAFile = 0xfefefefefefefefeL;
private final long notHFile = 0x7f7f7f7f7f7f7f7fL; private static final long notHFile = 0x7f7f7f7f7f7f7f7fL;
public BitboardReversi(Player<BitboardReversi>[] players) { public BitboardReversi(Player<BitboardReversi>[] players) {
super(8, 8, 2, players); super(8, 8, 2, players);

View File

@@ -52,4 +52,8 @@ public class ArtificialPlayer<T extends TurnBasedGame<T>> extends AbstractPlayer
public ArtificialPlayer<T> deepCopy() { public ArtificialPlayer<T> deepCopy() {
return new ArtificialPlayer<>(this); return new ArtificialPlayer<>(this);
} }
/**
 * Exposes the AI object stored in this player's {@code ai} field.
 *
 * @return the current value of the {@code ai} field
 */
public AI<T> getAi() {
    return this.ai;
}
} }

View File

@@ -5,23 +5,22 @@ import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random; import java.util.Random;
public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> { public abstract class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
private static class Node { protected static class Node {
public TurnBasedGame<?> state; public TurnBasedGame<?> state;
public long move; public long move;
public long unexpandedMoves; public long unexpandedMoves;
public Node parent; public Node parent;
public Node[] children; public Node[] children;
public int expanded;
public float value; public float value;
public int visits; public int visits;
public boolean solved; public float heuristic;
public float solvedValue;
public float solved;
public Node(TurnBasedGame<?> state, Node parent, long move) { public Node(TurnBasedGame<?> state, Node parent, long move) {
final long legalMoves = state.getLegalMoves(); final long legalMoves = state.getLegalMoves();
@@ -32,23 +31,26 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
this.unexpandedMoves = legalMoves; this.unexpandedMoves = legalMoves;
this.parent = parent; this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)]; this.children = new Node[Long.bitCount(legalMoves)];
this.expanded = 0;
this.value = 0.0f; this.value = 0.0f;
this.visits = 0; this.visits = 0;
this.solved = false; this.heuristic = state.rateMove(move);
this.solvedValue = 0.0f;
this.solved = Float.NaN;
} }
public Node(TurnBasedGame<?> state) { public Node(TurnBasedGame<?> state) {
this(state, null, 0L); this(state, null, 0L);
} }
/**
 * Reports how many child slots are already in use.
 *
 * @return the length of {@code children} minus the number of bits still set in
 *         {@code unexpandedMoves}
 */
public int getExpanded() {
    final int pendingMoves = Long.bitCount(unexpandedMoves);
    return children.length - pendingMoves;
}
public boolean isFullyExpanded() { public boolean isFullyExpanded() {
return expanded == children.length; return unexpandedMoves == 0L;
} }
public float calculateUCT(int parentVisits) { public float calculateUCT(int parentVisits) {
@@ -57,12 +59,15 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
} }
final float exploitation = value / visits; final float exploitation = value / visits;
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits));
final float bias = heuristic * 10.0f / (visits + 1);
return exploitation + exploration; return exploitation + exploration + bias;
} }
public Node bestUCTChild() { public Node bestUCTChild() {
final int expanded = getExpanded();
Node highestUCTChild = null; Node highestUCTChild = null;
float highestUCT = Float.NEGATIVE_INFINITY; float highestUCT = Float.NEGATIVE_INFINITY;
@@ -79,116 +84,34 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
} }
} }
private static final Random random = new Random(); protected static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
private final int milliseconds; protected final int milliseconds;
private Node root; protected int lastIterations;
public MCTSAI2(int milliseconds) { public MCTSAI(int milliseconds) {
this.milliseconds = milliseconds; this.milliseconds = milliseconds;
this.root = null;
} }
public MCTSAI2(MCTSAI2<T> other) { public MCTSAI(MCTSAI<T> other) {
this.milliseconds = other.milliseconds; this.milliseconds = other.milliseconds;
this.root = other.root;
} }
@Override public int getLastIterations() {
public MCTSAI2<T> deepCopy() { return lastIterations;
return new MCTSAI2<>(this);
} }
@Override protected Node selection(Node root) {
public long getMove(T game) { // while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) {
root = findOrResetRoot(root, game); while (root.isFullyExpanded() && !root.state.isTerminal()) {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
while (System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;
root = findChildByMove(root, move);
return move;
}
/**
 * Picks the expanded child of {@code root} with the largest visit count.
 *
 * @param root node whose first {@code root.expanded} children are inspected
 * @return the child with the most visits (the earliest one on ties), or {@code null} when
 *         {@code root.expanded} is zero
 */
private Node mostVisitedChild(Node root) {
    Node leader = null;
    int leaderVisits = -1;

    int slot = 0;
    while (slot < root.expanded) {
        final Node candidate = root.children[slot++];
        // Strict comparison keeps the first child when several share the top count.
        if (candidate.visits > leaderVisits) {
            leader = candidate;
            leaderVisits = candidate.visits;
        }
    }

    return leader;
}
/**
 * Chooses the tree root to search from for the given game position.
 *
 * @param root previously kept root, may be {@code null}
 * @param game position the search must start from
 * @return {@code root} itself when its board equals the game's board; otherwise the first
 *         expanded child whose board matches (detached from its parent); otherwise a new node
 *         holding a deep copy of {@code game}
 */
private Node findOrResetRoot(Node root, T game) {
    if (root != null) {
        if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
            return root;
        }

        for (int slot = 0; slot < root.expanded; slot++) {
            final Node candidate = root.children[slot];
            if (!areStatesEqual(candidate.state.getBoard(), game.getBoard())) {
                continue;
            }
            // Cut the link upwards so back-propagation stops at the new root.
            candidate.parent = null;
            return candidate;
        }
    }

    return new Node(game.deepCopy());
}
/**
 * Looks up the expanded child of {@code root} that was reached by {@code move}.
 *
 * @param root node whose first {@code root.expanded} children are searched
 * @param move move value to match against each child's {@code move} field
 * @return the first matching child with its {@code parent} cleared, or {@code null} if none
 *         matches
 */
private Node findChildByMove(Node root, long move) {
    Node match = null;
    for (int slot = 0; match == null && slot < root.expanded; slot++) {
        if (root.children[slot].move == move) {
            match = root.children[slot];
        }
    }

    if (match != null) {
        // Detach the subtree so it can serve as a root on its own.
        match.parent = null;
    }
    return match;
}
/**
 * Compares two board arrays element by element.
 *
 * <p>Delegates to {@link java.util.Arrays#equals(long[], long[])} instead of a hand-written
 * loop. For non-null inputs the result is the same as before; a {@code null} argument now
 * yields a boolean ({@code true} only when both are {@code null}) instead of a
 * {@link NullPointerException}.
 *
 * @param state1 first board array
 * @param state2 second board array
 * @return {@code true} when both arrays have the same length and identical elements
 */
private boolean areStatesEqual(long[] state1, long[] state2) {
    // Fully qualified so the file's import list does not have to change.
    return java.util.Arrays.equals(state1, state2);
}
private Node selection(Node root) {
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
root = root.bestUCTChild(); root = root.bestUCTChild();
} }
return root; return root;
} }
private Node expansion(Node leaf) { protected Node expansion(Node leaf) {
if (leaf.unexpandedMoves == 0L) { if (leaf.unexpandedMoves == 0L) {
return leaf; return leaf;
} }
@@ -200,15 +123,13 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.expanded] = expandedChild; leaf.children[leaf.getExpanded()] = expandedChild;
leaf.expanded++;
leaf.unexpandedMoves &= ~unexpandedMove; leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild; return expandedChild;
} }
private float simulation(Node leaf) { protected float simulation(Node leaf) {
final TurnBasedGame<?> copiedState = leaf.state.deepCopy(); final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn(); final int playerIndex = 1 - copiedState.getCurrentTurn();
@@ -230,12 +151,12 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
return 0.0f; return 0.0f;
} }
private void backPropagation(Node leaf, float value) { protected void backPropagation(Node leaf, float value) {
while (leaf != null) { while (leaf != null) {
leaf.value += value; leaf.value += value;
leaf.visits++; leaf.visits++;
if (!leaf.solved) { if (Float.isNaN(leaf.solved)) {
updateSolvedStatus(leaf); updateSolvedStatus(leaf);
} }
@@ -244,14 +165,91 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
} }
} }
/**
 * Picks the expanded child of {@code root} with the largest visit count.
 *
 * @param root node whose first {@code root.getExpanded()} children are inspected
 * @return the child with the most visits (the earliest one on ties), or {@code null} when no
 *         child has been expanded
 */
protected Node mostVisitedChild(Node root) {
    final int limit = root.getExpanded();

    Node leader = null;
    for (int slot = 0, leaderVisits = -1; slot < limit; slot++) {
        final Node candidate = root.children[slot];
        if (candidate.visits <= leaderVisits) {
            continue;
        }
        leader = candidate;
        leaderVisits = candidate.visits;
    }

    return leader;
}
/**
 * Chooses the tree root to search from for the given game position.
 *
 * @param root previously kept root, may be {@code null}
 * @param game position the search must start from
 * @return {@code root} itself when its board equals the game's board; otherwise the first
 *         expanded child whose board matches (detached from its parent); otherwise a new node
 *         holding a deep copy of {@code game}
 */
protected Node findOrResetRoot(Node root, T game) {
    if (root != null) {
        if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
            return root;
        }

        final int limit = root.getExpanded();
        int slot = 0;
        while (slot < limit) {
            final Node candidate = root.children[slot++];
            if (areStatesEqual(candidate.state.getBoard(), game.getBoard())) {
                // Cut the link upwards so back-propagation stops at the new root.
                candidate.parent = null;
                return candidate;
            }
        }
    }

    return new Node(game.deepCopy());
}
/**
 * Looks up the expanded child of {@code root} that was reached by {@code move}.
 *
 * @param root node whose first {@code root.getExpanded()} children are searched
 * @param move move value to match against each child's {@code move} field
 * @return the first matching child with its {@code parent} cleared, or {@code null} if none
 *         matches
 */
protected Node findChildByMove(Node root, long move) {
    final int limit = root.getExpanded();

    Node match = null;
    for (int slot = 0; match == null && slot < limit; slot++) {
        if (root.children[slot].move == move) {
            match = root.children[slot];
        }
    }

    if (match != null) {
        // Detach the subtree so it can serve as a root on its own.
        match.parent = null;
    }
    return match;
}
/**
 * Compares two board arrays element by element.
 *
 * <p>Delegates to {@link java.util.Arrays#equals(long[], long[])} instead of a hand-written
 * loop. For non-null inputs the result is the same as before; a {@code null} argument now
 * yields a boolean ({@code true} only when both are {@code null}) instead of a
 * {@link NullPointerException}.
 *
 * @param state1 first board array
 * @param state2 second board array
 * @return {@code true} when both arrays have the same length and identical elements
 */
protected boolean areStatesEqual(long[] state1, long[] state2) {
    // Fully qualified so the file's import list does not have to change.
    return java.util.Arrays.equals(state1, state2);
}
/**
 * Selects one of the set bits of {@code value} using the calling thread's {@code Random}.
 *
 * @param value bit mask to pick from
 * @return a mask with exactly one bit set, taken from {@code value}; {@code 0} when
 *         {@code value} has no bits set
 */
protected long randomSetBit(long value) {
    if (value == 0L) {
        return 0L;
    }

    // Number of low-order set bits to discard before taking the next one.
    int toSkip = random.get().nextInt(Long.bitCount(value));

    long remaining = value;
    while (toSkip-- > 0) {
        remaining &= remaining - 1; // clears the lowest set bit
    }

    return Long.lowestOneBit(remaining);
}
private void updateSolvedStatus(Node node) { private void updateSolvedStatus(Node node) {
if (node.state.isTerminal()) { if (node.state.isTerminal()) {
node.solved = true;
final int winner = node.state.getWinner(); final int winner = node.state.getWinner();
final int mover = 1 - node.state.getCurrentTurn(); final int mover = 1 - node.state.getCurrentTurn();
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f; node.solved = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
return; return;
} }
@@ -262,13 +260,13 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
boolean foundDrawMove = false; boolean foundDrawMove = false;
for (final Node child : node.children) { for (final Node child : node.children) {
if (child.solved) { if (!Float.isNaN(child.solved)) {
if (child.solvedValue == -1.0f) { if (child.solved == -1.0f) {
foundWinningMove = true; foundWinningMove = true;
break; break;
} }
if (child.solvedValue == 0.0f) { if (child.solved == 0.0f) {
foundDrawMove = true; foundDrawMove = true;
} }
} else { } else {
@@ -277,27 +275,10 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
} }
if (foundWinningMove) { if (foundWinningMove) {
node.solved = true; node.solved = 1.0f;
node.solvedValue = 1.0f;
} else if (allChildrenSolved) { } else if (allChildrenSolved) {
node.solved = true; node.solved = foundDrawMove? 0.0f : -1.0f;
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
} }
} }
} }
/**
 * Selects one of the set bits of {@code value} using the class's {@code random} instance.
 *
 * @param value bit mask to pick from
 * @return a mask with exactly one bit set, taken from {@code value}; {@code 0} when
 *         {@code value} has no bits set
 */
private long randomSetBit(long value) {
    if (value == 0L) {
        return 0L;
    }

    // Number of low-order set bits to discard before taking the next one.
    int toSkip = random.nextInt(Long.bitCount(value));

    long remaining = value;
    while (toSkip-- > 0) {
        remaining &= remaining - 1; // clears the lowest set bit
    }

    return Long.lowestOneBit(remaining);
}
} }

View File

@@ -1,250 +0,0 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
/**
 * Single-threaded Monte Carlo Tree Search AI. Every call to {@link #getMove} builds a new tree
 * and repeats selection, expansion, simulation and back-propagation until the time budget is
 * spent, then answers with the most visited move.
 *
 * <p>Scores are kept relative to player index {@code 1 - getCurrentTurn()} of a node's state
 * (presumably the player who just moved into that node in a two-player game); back-propagation
 * negates the score on every step towards the root.
 *
 * @param <T> the game type this AI plays
 */
public class MCTSAI1<T extends TurnBasedGame<T>> extends AbstractAI<T> {
    /** One position in the search tree together with its statistics. */
    private static class Node {
        public TurnBasedGame<?> state;
        /** Move that produced {@link #state}; {@code 0L} for a root. */
        public long move;
        /** Bit mask of legal moves that have no child node yet. */
        public long unexpandedMoves;

        public Node parent;
        /** One slot per legal move; only the first {@link #expanded} slots are filled. */
        public Node[] children;
        public int expanded;

        /** Sum of back-propagated simulation results. */
        public float value;
        public int visits;

        /** Whether {@link #solvedValue} holds a proven result for this node. */
        public boolean solved;
        public float solvedValue;

        /**
         * Creates a node and sizes its child array from the state's legal moves.
         *
         * @param state  position represented by this node (stored as given, not copied)
         * @param parent parent node, or {@code null} for a root
         * @param move   move that led to {@code state}
         */
        public Node(TurnBasedGame<?> state, Node parent, long move) {
            final long legalMoves = state.getLegalMoves();

            this.state = state;
            this.move = move;
            this.unexpandedMoves = legalMoves;

            this.parent = parent;
            this.children = new Node[Long.bitCount(legalMoves)];
            this.expanded = 0;

            this.value = 0.0f;
            this.visits = 0;

            this.solved = false;
            this.solvedValue = 0.0f;
        }

        /** Creates a root node for {@code state}. */
        public Node(TurnBasedGame<?> state) {
            this(state, null, 0L);
        }

        /** @return {@code true} when every child slot has been filled */
        public boolean isFullyExpanded() {
            return expanded == children.length;
        }

        /**
         * Computes the selection score: mean value plus an exploration bonus weighted by 1.41.
         *
         * @param parentVisits visit count of the parent node
         * @return the score, or positive infinity for a node that has never been visited
         */
        public float calculateUCT(int parentVisits) {
            if (visits == 0) {
                return Float.POSITIVE_INFINITY;
            }

            final float exploitation = value / visits;
            final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));

            return exploitation + exploration;
        }

        /** @return the expanded child with the highest score, or {@code null} without children */
        public Node bestUCTChild() {
            Node highestUCTChild = null;
            float highestUCT = Float.NEGATIVE_INFINITY;

            for (int i = 0; i < expanded; i++) {
                final float childUCT = children[i].calculateUCT(visits);

                if (childUCT > highestUCT) {
                    highestUCTChild = children[i];
                    highestUCT = childUCT;
                }
            }

            return highestUCTChild;
        }
    }

    private static final Random random = new Random();

    /** Search budget per move, in milliseconds. */
    private final int milliseconds;

    /**
     * @param milliseconds search budget per move, in milliseconds
     */
    public MCTSAI1(int milliseconds) {
        this.milliseconds = milliseconds;
    }

    /** Copy constructor; copies the time budget. */
    public MCTSAI1(MCTSAI1<T> other) {
        this.milliseconds = other.milliseconds;
    }

    @Override
    public MCTSAI1<T> deepCopy() {
        return new MCTSAI1<>(this);
    }

    /**
     * Searches from {@code game} for the configured time and returns the most visited move.
     *
     * @param game current position; it is stored in the root without copying, and every
     *             expansion and simulation works on a deep copy of a node's state
     * @return the move of the most visited root child; when no child was expanded (for example
     *         a zero time budget or a position without legal moves) a random legal move, which
     *         is {@code 0} if there are none
     */
    @Override
    public long getMove(T game) {
        final Node root = new Node(game, null, 0L);

        final long endTime = System.nanoTime() + milliseconds * 1_000_000L;

        // Compare by difference: nanoTime values may wrap, so "now < endTime" is not reliable.
        while (System.nanoTime() - endTime < 0) {
            Node leaf = selection(root);
            leaf = expansion(leaf);
            final float value = simulation(leaf);
            backPropagation(leaf, value);
        }

        final Node mostVisitedChild = mostVisitedChild(root);
        if (mostVisitedChild == null) {
            // Nothing was expanded; avoid dereferencing null and fall back to any legal move.
            return randomSetBit(game.getLegalMoves());
        }
        return mostVisitedChild.move;
    }

    /** @return the expanded child of {@code root} with the most visits, or {@code null} */
    private Node mostVisitedChild(Node root) {
        Node mostVisitedChild = null;
        int mostVisited = -1;

        for (int i = 0; i < root.expanded; i++) {
            if (root.children[i].visits > mostVisited) {
                mostVisitedChild = root.children[i];
                mostVisited = root.children[i].visits;
            }
        }

        return mostVisitedChild;
    }

    /**
     * Walks down from {@code root} along the best-scoring children until it reaches a node that
     * is solved, not fully expanded, or terminal.
     */
    private Node selection(Node root) {
        while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
            root = root.bestUCTChild();
        }

        return root;
    }

    /**
     * Adds one child to {@code leaf} for its lowest-order unexpanded move.
     *
     * @return the new child, or {@code leaf} itself when nothing is left to expand
     */
    private Node expansion(Node leaf) {
        if (leaf.unexpandedMoves == 0L) {
            return leaf;
        }

        final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;

        final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
        copiedState.play(unexpandedMove);

        final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);

        leaf.children[leaf.expanded] = expandedChild;
        leaf.expanded++;

        leaf.unexpandedMoves &= ~unexpandedMove;

        return expandedChild;
    }

    /**
     * Plays uniformly random moves on a copy of the leaf's state until the game ends.
     *
     * @return {@code 1} if player {@code 1 - getCurrentTurn()} of the leaf state wins,
     *         {@code -1} if another player wins, {@code 0} otherwise
     */
    private float simulation(Node leaf) {
        final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
        final int playerIndex = 1 - copiedState.getCurrentTurn();

        while (!copiedState.isTerminal()) {
            final long legalMoves = copiedState.getLegalMoves();
            final long randomMove = randomSetBit(legalMoves);

            copiedState.play(randomMove);
        }

        final int winner = copiedState.getWinner();
        if (winner == playerIndex) {
            return 1.0f;
        }
        if (winner >= 0) {
            return -1.0f;
        }
        return 0.0f;
    }

    /**
     * Adds {@code value} to every node from {@code leaf} up to the root, flipping its sign at
     * each level, and refreshes the solved flag of nodes that are not solved yet.
     */
    private void backPropagation(Node leaf, float value) {
        while (leaf != null) {
            leaf.value += value;
            leaf.visits++;

            if (!leaf.solved) {
                updateSolvedStatus(leaf);
            }

            value = -value;
            leaf = leaf.parent;
        }
    }

    /**
     * Marks {@code node} as solved when its result is proven.
     *
     * <p>{@code solvedValue} uses the same sign convention as {@code value}: it is scored for
     * player {@code 1 - getCurrentTurn()} of the node's state. A child's value is therefore
     * scored for the side that chooses among the children, which gives these rules for a fully
     * expanded node:
     * <ul>
     *   <li>any child proven {@code +1}: the chooser has a winning option, so the node is
     *       {@code -1};</li>
     *   <li>all children proven and at least one is {@code 0}: the node is {@code 0};</li>
     *   <li>all children proven {@code -1}: every option loses, so the node is {@code +1}.</li>
     * </ul>
     * The earlier rule treated a single {@code -1} child as a proven win for the node, although
     * the chooser is free to avoid that child.
     */
    private void updateSolvedStatus(Node node) {
        if (node.state.isTerminal()) {
            node.solved = true;

            final int winner = node.state.getWinner();
            final int mover = 1 - node.state.getCurrentTurn();

            node.solvedValue = winner == mover ? 1.0f : winner == -1 ? 0.0f : -1.0f;

            return;
        }

        if (node.isFullyExpanded()) {
            boolean allChildrenSolved = true;
            boolean chooserCanWin = false;
            boolean chooserCanDraw = false;

            for (final Node child : node.children) {
                if (!child.solved) {
                    allChildrenSolved = false;
                    continue;
                }
                if (child.solvedValue == 1.0f) {
                    chooserCanWin = true;
                    break;
                }
                if (child.solvedValue == 0.0f) {
                    chooserCanDraw = true;
                }
            }

            if (chooserCanWin) {
                node.solved = true;
                node.solvedValue = -1.0f;
            } else if (allChildrenSolved) {
                node.solved = true;
                node.solvedValue = chooserCanDraw ? 0.0f : 1.0f;
            }
        }
    }

    /**
     * @return a mask containing one randomly chosen set bit of {@code value}, or {@code 0} when
     *         {@code value} is {@code 0}
     */
    private long randomSetBit(long value) {
        if (0L == value) {
            return 0;
        }

        final int bitCount = Long.bitCount(value);
        final int randomBitCount = random.nextInt(bitCount);

        // Drop that many low-order set bits, then isolate the next one.
        for (int i = 0; i < randomBitCount; i++) {
            value &= value - 1;
        }

        return value & -value;
    }
}

View File

@@ -1,297 +0,0 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
private static class Node {
public TurnBasedGame<?> state;
public long move;
public long unexpandedMoves;
public Node parent;
public Node[] children;
public int expanded;
public float value;
public int visits;
public boolean solved;
public float solvedValue;
public Node(TurnBasedGame<?> state, Node parent, long move) {
final long legalMoves = state.getLegalMoves();
this.state = state;
this.move = move;
this.unexpandedMoves = legalMoves;
this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)];
this.expanded = 0;
this.value = 0.0f;
this.visits = 0;
this.solved = false;
this.solvedValue = 0.0f;
}
public Node(TurnBasedGame<?> state) {
this(state, null, 0L);
}
public boolean isFullyExpanded() {
return expanded == children.length;
}
public float calculateUCT(int parentVisits) {
if (visits == 0) {
return Float.POSITIVE_INFINITY;
}
final float exploitation = value / visits;
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
return exploitation + exploration;
}
public Node bestUCTChild() {
Node highestUCTChild = null;
float highestUCT = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) {
final float childUCT = children[i].calculateUCT(visits);
if (childUCT > highestUCT) {
highestUCTChild = children[i];
highestUCT = childUCT;
}
}
return highestUCTChild;
}
}
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
private final int milliseconds;
private final int threads;
public MCTSAI3(int milliseconds, int threads) {
this.milliseconds = milliseconds;
this.threads = threads;
}
public MCTSAI3(MCTSAI3<T> other) {
this.milliseconds = other.milliseconds;
this.threads = other.threads;
}
@Override
public MCTSAI3<T> deepCopy() {
return new MCTSAI3<>(this);
}
@Override
public long getMove(T game) {
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
tasks.add(() -> {
final Node localRoot = new Node(game.deepCopy());
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move;
} catch (Exception _) {
final long legalMoves = game.getLegalMoves();
return randomSetBit(legalMoves);
}
}
private Node mostVisitedChild(Node root) {
Node mostVisitedChild = null;
int mostVisited = -1;
for (int i = 0; i < root.expanded; i++) {
if (root.children[i].visits > mostVisited) {
mostVisitedChild = root.children[i];
mostVisited = root.children[i].visits;
}
}
return mostVisitedChild;
}
private Node selection(Node root) {
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
root = root.bestUCTChild();
}
return root;
}
private Node expansion(Node leaf) {
if (leaf.unexpandedMoves == 0L) {
return leaf;
}
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.expanded] = expandedChild;
leaf.expanded++;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
}
private float simulation(Node leaf) {
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn();
while (!copiedState.isTerminal()) {
final long legalMoves = copiedState.getLegalMoves();
final long randomMove = randomSetBit(legalMoves);
copiedState.play(randomMove);
}
if (copiedState.getWinner() == playerIndex) {
return 1.0f;
}
if (copiedState.getWinner() >= 0) {
return -1.0f;
}
return 0.0f;
}
private void backPropagation(Node leaf, float value) {
while (leaf != null) {
leaf.value += value;
leaf.visits++;
if (!leaf.solved) {
updateSolvedStatus(leaf);
}
value = -value;
leaf = leaf.parent;
}
}
private void updateSolvedStatus(Node node) {
if (node.state.isTerminal()) {
node.solved = true;
final int winner = node.state.getWinner();
final int mover = 1 - node.state.getCurrentTurn();
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
return;
}
if (node.isFullyExpanded()) {
boolean allChildrenSolved = true;
boolean foundWinningMove = false;
boolean foundDrawMove = false;
for (final Node child : node.children) {
if (child.solved) {
if (child.solvedValue == -1.0f) {
foundWinningMove = true;
break;
}
if (child.solvedValue == 0.0f) {
foundDrawMove = true;
}
} else {
allChildrenSolved = false;
}
}
if (foundWinningMove) {
node.solved = true;
node.solvedValue = 1.0f;
} else if (allChildrenSolved) {
node.solved = true;
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
}
}
}
private long randomSetBit(long value) {
if (0L == value) {
return 0;
}
final int bitCount = Long.bitCount(value);
final int randomBitCount = random.get().nextInt(bitCount);
for (int i = 0; i < randomBitCount; i++) {
value &= value - 1;
}
return value & -value;
}
}

View File

@@ -1,359 +0,0 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MCTSAI4<T extends TurnBasedGame<T>> extends AbstractAI<T> {
/** One position in a search tree together with its visit and value statistics. */
private static class Node {
    public TurnBasedGame<?> state;
    /** Move that produced {@link #state}; {@code 0L} for a root. */
    public long move;
    /** Bit mask of legal moves that have no child node yet. */
    public long unexpandedMoves;

    public Node parent;
    /** One slot per legal move; only the first {@link #expanded} slots are filled. */
    public Node[] children;
    public int expanded;

    /** Sum of back-propagated simulation results. */
    public float value;
    public int visits;

    /** Whether {@link #solvedValue} holds a proven result for this node. */
    public boolean solved;
    public float solvedValue;

    /**
     * Builds a node whose child array has one slot per legal move of {@code state}.
     *
     * @param state  position represented by this node (stored as given, not copied)
     * @param parent parent node, or {@code null} for a root
     * @param move   move that led to {@code state}
     */
    public Node(TurnBasedGame<?> state, Node parent, long move) {
        this.state = state;
        this.parent = parent;
        this.move = move;

        unexpandedMoves = state.getLegalMoves();
        children = new Node[Long.bitCount(unexpandedMoves)];

        // expanded, value, visits, solved and solvedValue keep Java's zero/false defaults.
    }

    /** Builds a root node: no parent and a move of {@code 0L}. */
    public Node(TurnBasedGame<?> state) {
        this(state, null, 0L);
    }

    /** @return {@code true} once the number of filled slots equals the array length */
    public boolean isFullyExpanded() {
        return children.length == expanded;
    }

    /**
     * Selection score of this node: mean value plus an exploration bonus weighted by 1.41.
     *
     * @param parentVisits visit count of the parent node
     * @return the score, or positive infinity for a node that has never been visited
     */
    public float calculateUCT(int parentVisits) {
        if (visits != 0) {
            final float mean = value / visits;
            final float bonus = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
            return mean + bonus;
        }

        return Float.POSITIVE_INFINITY;
    }

    /**
     * @return the filled child with the highest score (the earliest one on ties), or
     *         {@code null} when no child has been added yet
     */
    public Node bestUCTChild() {
        Node best = null;
        float bestScore = Float.NEGATIVE_INFINITY;

        int slot = 0;
        while (slot < expanded) {
            final Node candidate = children[slot++];
            final float score = candidate.calculateUCT(visits);

            if (score > bestScore) {
                best = candidate;
                bestScore = score;
            }
        }

        return best;
    }
}
// One Random per thread, created lazily on first use by that thread.
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
// Search budget per move, in milliseconds (converted to nanoseconds in getMove).
private final int milliseconds;
// Number of worker threads started for each move.
private final int threads;
// One search-tree root per worker thread, kept between calls to getMove.
private final Node[] threadRoots;
/**
 * Creates an AI with an empty root slot for every worker thread.
 *
 * @param milliseconds search budget per move, in milliseconds
 * @param threads      number of worker threads; also the length of the root array
 */
public MCTSAI4(int milliseconds, int threads) {
    this.threads = threads;
    this.milliseconds = milliseconds;
    // Slots start out null; getMove fills them through findOrResetRoot.
    this.threadRoots = new Node[threads];
}
/**
 * Copy constructor: takes over the time budget, thread count and root array of {@code other}.
 *
 * @param other instance to copy the configuration from
 */
public MCTSAI4(MCTSAI4<T> other) {
    this.threads = other.threads;
    this.milliseconds = other.milliseconds;
    // NOTE(review): the array reference is shared, not cloned, so this instance and `other`
    // read and write the same roots — confirm that is intended for deepCopy().
    this.threadRoots = other.threadRoots;
}
/**
 * Creates another instance through the copy constructor.
 *
 * @return a new {@code MCTSAI4} built from this one
 */
@Override
public MCTSAI4<T> deepCopy() {
    final MCTSAI4<T> duplicate = new MCTSAI4<>(this);
    return duplicate;
}
@Override
public long getMove(T game) {
for (int i = 0; i < threads; i++) {
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
final int threadIndex = i;
tasks.add(() -> {
final Node localRoot = threadRoots[threadIndex];
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;
for (int i = 0; i < threads; i++) {
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
return move;
} catch (Exception _) {
final long legalMoves = game.getLegalMoves();
return randomSetBit(legalMoves);
}
}
/**
 * Finds the filled child of {@code root} that was visited most often.
 *
 * @param root node whose first {@code root.expanded} children are inspected
 * @return that child (the earliest one on ties), or {@code null} when {@code root.expanded}
 *         is zero
 */
private Node mostVisitedChild(Node root) {
    Node leader = null;

    for (int slot = 0, leaderVisits = -1; slot < root.expanded; slot++) {
        final Node candidate = root.children[slot];
        if (candidate.visits <= leaderVisits) {
            continue;
        }
        leader = candidate;
        leaderVisits = candidate.visits;
    }

    return leader;
}
/**
 * Re-roots a cached tree on the position in {@code game}.
 *
 * <p>Keeps {@code root} if its board already matches; otherwise promotes the
 * expanded child whose board matches (detaching it from its parent). If nothing
 * matches, or there is no cached tree, a fresh root is built from a copy of
 * {@code game}.
 *
 * @param root cached root, may be {@code null}
 * @param game the current position
 * @return the root to search from
 */
private Node findOrResetRoot(Node root, T game) {
    if (root != null) {
        if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
            return root;
        }
        int index = 0;
        while (index < root.expanded) {
            final Node child = root.children[index];
            if (areStatesEqual(child.state.getBoard(), game.getBoard())) {
                child.parent = null;
                return child;
            }
            index++;
        }
    }
    return new Node(game.deepCopy());
}
/**
 * Looks up the expanded child of {@code root} that was reached by {@code move}
 * and detaches it from its parent so it can serve as a new root.
 *
 * @param root node whose expanded children are searched
 * @param move move identifying the child
 * @return the detached child, or {@code null} if no expanded child has that move
 */
private Node findChildByMove(Node root, long move) {
    int index = 0;
    while (index < root.expanded) {
        final Node child = root.children[index];
        if (child.move == move) {
            child.parent = null;
            return child;
        }
        index++;
    }
    return null;
}
/**
 * Compares two board representations element by element.
 *
 * <p>Delegates to {@link java.util.Arrays#equals(long[], long[])} instead of a
 * hand-written loop; arrays of different length are never equal.
 *
 * @param state1 first board
 * @param state2 second board
 * @return {@code true} if both arrays have the same length and contents
 */
private boolean areStatesEqual(long[] state1, long[] state2) {
    return java.util.Arrays.equals(state1, state2);
}
/**
 * Descends from {@code root} along the best-UCT children.
 *
 * <p>The descent stops at the first node that is solved, still has unexpanded
 * moves, or holds a terminal state.
 *
 * @param root node to start from
 * @return the node at which the descent stopped
 */
private Node selection(Node root) {
    Node current = root;
    while (true) {
        if (current.solved || !current.isFullyExpanded() || current.state.isTerminal()) {
            return current;
        }
        current = current.bestUCTChild();
    }
}
/**
 * Adds one child to {@code leaf} for its lowest-order unexpanded move.
 *
 * @param leaf node to expand
 * @return the new child, or {@code leaf} itself when it has no unexpanded moves
 */
private Node expansion(Node leaf) {
    final long pending = leaf.unexpandedMoves;
    if (pending != 0L) {
        // Lowest set bit = next move to try.
        final long move = Long.lowestOneBit(pending);
        final TurnBasedGame<?> next = leaf.state.deepCopy();
        next.play(move);
        final Node child = new Node(next, leaf, move);
        leaf.children[leaf.expanded++] = child;
        leaf.unexpandedMoves = pending & ~move;
        return child;
    }
    return leaf;
}
/**
 * Plays uniformly random moves from a copy of the leaf state until the game ends.
 *
 * <p>The result is scored for player index {@code 1 - getCurrentTurn()} of the
 * leaf state (presumably the player who moved into the leaf; assumes two players).
 *
 * @param leaf node to simulate from
 * @return 1 if that player wins, -1 if another player wins, 0 when there is no winner
 */
private float simulation(Node leaf) {
    final TurnBasedGame<?> playout = leaf.state.deepCopy();
    final int perspective = 1 - playout.getCurrentTurn();
    while (!playout.isTerminal()) {
        playout.play(randomSetBit(playout.getLegalMoves()));
    }
    final int winner = playout.getWinner();
    if (winner == perspective) {
        return 1.0f;
    }
    return winner >= 0 ? -1.0f : 0.0f;
}
/**
 * Walks from {@code leaf} up to the root, adding the playout result to every node.
 *
 * <p>The sign of the result flips at each level, and nodes that are not yet
 * solved get their solved status re-evaluated.
 *
 * @param leaf node the playout started from
 * @param value playout result as seen from {@code leaf}
 */
private void backPropagation(Node leaf, float value) {
    float signed = value;
    for (Node node = leaf; node != null; node = node.parent) {
        node.visits++;
        node.value += signed;
        if (!node.solved) {
            updateSolvedStatus(node);
        }
        signed = -signed;
    }
}
/**
 * Marks {@code node} as solved when its outcome is known.
 *
 * <p>A terminal node is solved by the game result, scored for player index
 * {@code 1 - getCurrentTurn()}. A fully expanded node is solved as a win (1) as
 * soon as one child is solved with -1. Otherwise, once all children are solved,
 * it is a draw (0) if any child is a draw and a loss (-1) if none is.
 *
 * @param node node to evaluate
 */
private void updateSolvedStatus(Node node) {
    if (node.state.isTerminal()) {
        node.solved = true;
        final int winner = node.state.getWinner();
        final int mover = 1 - node.state.getCurrentTurn();
        if (winner == mover) {
            node.solvedValue = 1.0f;
        } else if (winner == -1) {
            node.solvedValue = 0.0f;
        } else {
            node.solvedValue = -1.0f;
        }
        return;
    }

    if (!node.isFullyExpanded()) {
        return;
    }

    boolean everyChildSolved = true;
    boolean drawAvailable = false;
    for (final Node child : node.children) {
        if (!child.solved) {
            everyChildSolved = false;
            continue;
        }
        if (child.solvedValue == -1.0f) {
            // A child that is lost for the opponent makes this node a win.
            node.solved = true;
            node.solvedValue = 1.0f;
            return;
        }
        if (child.solvedValue == 0.0f) {
            drawAvailable = true;
        }
    }

    if (everyChildSolved) {
        node.solved = true;
        node.solvedValue = drawAvailable ? 0.0f : -1.0f;
    }
}
/**
 * Picks one of the set bits of {@code value} uniformly at random.
 *
 * @param value bit set to choose from
 * @return a value with exactly one of the input's set bits, or 0 if the input is 0
 */
private long randomSetBit(long value) {
    if (value == 0L) {
        return 0L;
    }
    int skip = random.get().nextInt(Long.bitCount(value));
    long remaining = value;
    // Drop the lowest set bit `skip` times, then isolate the next one.
    while (skip-- > 0) {
        remaining &= remaining - 1;
    }
    return Long.lowestOneBit(remaining);
}
}

View File

@@ -1,371 +0,0 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MCTSAI5<T extends TurnBasedGame<T>> extends AbstractAI<T> {
private static class Node {
public TurnBasedGame<?> state;
public long move;
public long unexpandedMoves;
public Node parent;
public Node[] children;
public int expanded;
public float value;
public int visits;
public boolean solved;
public float solvedValue;
public float heuristic;
public Node(TurnBasedGame<?> state, Node parent, long move) {
final long legalMoves = state.getLegalMoves();
this.state = state;
this.move = move;
this.unexpandedMoves = legalMoves;
this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)];
this.expanded = 0;
this.value = 0.0f;
this.visits = 0;
this.solved = false;
this.solvedValue = 0.0f;
this.heuristic = state.rateMove(move);
}
public Node(TurnBasedGame<?> state) {
this(state, null, 0L);
}
public boolean isFullyExpanded() {
return expanded == children.length;
}
public float calculateUCT(int parentVisits) {
if (visits == 0) {
return Float.POSITIVE_INFINITY;
}
final float exploitation = value / visits;
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
final float bias = heuristic / visits;
return exploitation + exploration + bias;
}
public Node bestUCTChild() {
Node highestUCTChild = null;
float highestUCT = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) {
final float childUCT = children[i].calculateUCT(visits);
if (childUCT > highestUCT) {
highestUCTChild = children[i];
highestUCT = childUCT;
}
}
return highestUCTChild;
}
}
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
private final int milliseconds;
private final int threads;
private final Node[] threadRoots;
public MCTSAI5(int milliseconds, int threads) {
this.milliseconds = milliseconds;
this.threads = threads;
this.threadRoots = new Node[threads];
}
public MCTSAI5(MCTSAI5<T> other) {
this.milliseconds = other.milliseconds;
this.threads = other.threads;
this.threadRoots = other.threadRoots;
}
@Override
public MCTSAI5<T> deepCopy() {
return new MCTSAI5<>(this);
}
@Override
public long getMove(T game) {
for (int i = 0; i < threads; i++) {
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
final int threadIndex = i;
tasks.add(() -> {
final Node localRoot = threadRoots[threadIndex];
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;
for (int i = 0; i < threads; i++) {
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
return move;
} catch (Exception _) {
final long legalMoves = game.getLegalMoves();
return randomSetBit(legalMoves);
}
}
private Node mostVisitedChild(Node root) {
Node mostVisitedChild = null;
int mostVisited = -1;
for (int i = 0; i < root.expanded; i++) {
if (root.children[i].visits > mostVisited) {
mostVisitedChild = root.children[i];
mostVisited = root.children[i].visits;
}
}
return mostVisitedChild;
}
private Node findOrResetRoot(Node root, T game) {
if (root == null) {
return new Node(game.deepCopy());
}
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
return root;
}
for (int i = 0; i < root.expanded; i++) {
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
root.children[i].parent = null;
return root.children[i];
}
}
return new Node(game.deepCopy());
}
private Node findChildByMove(Node root, long move) {
for (int i = 0; i < root.expanded; i++) {
if (root.children[i].move == move) {
root.children[i].parent = null;
return root.children[i];
}
}
return null;
}
private boolean areStatesEqual(long[] state1, long[] state2) {
if (state1.length != state2.length) {
return false;
}
for (int i = 0; i < state1.length; i++) {
if (state1[i] != state2[i]) {
return false;
}
}
return true;
}
private Node selection(Node root) {
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
root = root.bestUCTChild();
}
return root;
}
private Node expansion(Node leaf) {
if (leaf.unexpandedMoves == 0L) {
return leaf;
}
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.expanded] = expandedChild;
leaf.expanded++;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
}
private float simulation(Node leaf) {
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn();
while (!copiedState.isTerminal()) {
final long legalMoves = copiedState.getLegalMoves();
long move = 0L;
if (random.get().nextFloat() > 0.9f) {
move = copiedState.heuristicMove(legalMoves);
} else {
move = randomSetBit(legalMoves);
}
copiedState.play(move);
}
if (copiedState.getWinner() == playerIndex) {
return 1.0f;
}
if (copiedState.getWinner() >= 0) {
return -1.0f;
}
return 0.0f;
}
private void backPropagation(Node leaf, float value) {
while (leaf != null) {
leaf.value += value;
leaf.visits++;
if (!leaf.solved) {
updateSolvedStatus(leaf);
}
value = -value;
leaf = leaf.parent;
}
}
private void updateSolvedStatus(Node node) {
if (node.state.isTerminal()) {
node.solved = true;
final int winner = node.state.getWinner();
final int mover = 1 - node.state.getCurrentTurn();
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
return;
}
if (node.isFullyExpanded()) {
boolean allChildrenSolved = true;
boolean foundWinningMove = false;
boolean foundDrawMove = false;
for (final Node child : node.children) {
if (child.solved) {
if (child.solvedValue == -1.0f) {
foundWinningMove = true;
break;
}
if (child.solvedValue == 0.0f) {
foundDrawMove = true;
}
} else {
allChildrenSolved = false;
}
}
if (foundWinningMove) {
node.solved = true;
node.solvedValue = 1.0f;
} else if (allChildrenSolved) {
node.solved = true;
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
}
}
}
private long randomSetBit(long value) {
if (0L == value) {
return 0;
}
final int bitCount = Long.bitCount(value);
final int randomBitCount = random.get().nextInt(bitCount);
for (int i = 0; i < randomBitCount; i++) {
value &= value - 1;
}
return value & -value;
}
}

View File

@@ -0,0 +1,39 @@
package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
/**
 * Single-threaded MCTS variant that builds a new tree for every move request.
 * Tree operations and the time budget come from the {@code MCTSAI} base class.
 *
 * @param <T> the game type
 */
public class MCTSAI1<T extends TurnBasedGame<T>> extends MCTSAI<T> {
    /**
     * @param milliseconds time budget handed to the base class
     */
    public MCTSAI1(int milliseconds) {
        super(milliseconds);
    }

    /**
     * Copy constructor; delegates to the base class.
     *
     * @param other instance to copy
     */
    public MCTSAI1(MCTSAI1<T> other) {
        super(other);
    }

    @Override
    public MCTSAI1<T> deepCopy() {
        final MCTSAI1<T> duplicate = new MCTSAI1<>(this);
        return duplicate;
    }

    /**
     * Searches from a fresh root until the time budget is used up and returns
     * the move of the most visited root child.
     *
     * @param game the position to move in
     * @return the chosen move
     */
    @Override
    public long getMove(T game) {
        final Node searchRoot = new Node(game, null, 0L);
        final long deadline = System.nanoTime() + milliseconds * 1_000_000L;

        // NOTE: stopping early once the root is solved (Float.isNaN(root.solved))
        // is currently disabled; the loop always runs for the full time budget.
        for (long now = System.nanoTime(); now < deadline; now = System.nanoTime()) {
            final Node leaf = expansion(selection(searchRoot));
            backPropagation(leaf, simulation(leaf));
        }

        lastIterations = searchRoot.visits;

        return mostVisitedChild(searchRoot).move;
    }
}

View File

@@ -0,0 +1,49 @@
package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
/**
 * Single-threaded MCTS variant that keeps its search tree between moves.
 * Tree operations and the time budget come from the {@code MCTSAI} base class.
 *
 * @param <T> the game type
 */
public class MCTSAI2<T extends TurnBasedGame<T>> extends MCTSAI<T> {
    // Cached tree root; null until the first search and whenever no subtree can be reused.
    private Node root;

    /**
     * @param milliseconds time budget handed to the base class
     */
    public MCTSAI2(int milliseconds) {
        super(milliseconds);

        this.root = null;
    }

    /**
     * Copy constructor. The copy starts without a cached tree instead of sharing
     * {@code other.root}: the tree is mutated during every search, so two
     * instances must not hold the same nodes. {@code getMove} rebuilds the root
     * through {@code findOrResetRoot}, exactly as after the regular constructor.
     *
     * @param other instance to copy
     */
    public MCTSAI2(MCTSAI2<T> other) {
        super(other);

        this.root = null;
    }

    @Override
    public MCTSAI2<T> deepCopy() {
        return new MCTSAI2<>(this);
    }

    /**
     * Re-roots the cached tree on {@code game}, searches until the time budget
     * is used up, and keeps the subtree of the chosen move for the next call.
     *
     * @param game the position to move in
     * @return the move of the most visited root child
     */
    @Override
    public long getMove(T game) {
        root = findOrResetRoot(root, game);

        final long endTime = System.nanoTime() + milliseconds * 1_000_000L;

        // NOTE: stopping early once the root is solved (Float.isNaN(root.solved))
        // is currently disabled; the loop always runs for the full time budget.
        while (System.nanoTime() < endTime) {
            Node leaf = selection(root);
            leaf = expansion(leaf);
            final float value = simulation(leaf);
            backPropagation(leaf, value);
        }

        lastIterations = root.visits;

        final Node mostVisitedChild = mostVisitedChild(root);
        final long move = mostVisitedChild.move;

        root = findChildByMove(root, move);

        return move;
    }
}

View File

@@ -0,0 +1,91 @@
package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MCTSAI3<T extends TurnBasedGame<T>> extends MCTSAI<T> {
private final int threads;
public MCTSAI3(int milliseconds, int threads) {
super(milliseconds);
this.threads = threads;
}
public MCTSAI3(MCTSAI3<T> other) {
super(other);
this.threads = other.threads;
}
@Override
public MCTSAI3<T> deepCopy() {
return new MCTSAI3<>(this);
}
@Override
public long getMove(T game) {
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
tasks.add(() -> {
final Node localRoot = new Node(game.deepCopy());
while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move;
} catch (Exception _) {
lastIterations = 0;
final long legalMoves = game.getLegalMoves();
return randomSetBit(legalMoves);
}
}
}

View File

@@ -0,0 +1,107 @@
package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
public class MCTSAI4<T extends TurnBasedGame<T>> extends MCTSAI<T> {
private final int threads;
private final Node[] threadRoots;
public MCTSAI4(int milliseconds, int threads) {
super(milliseconds);
this.threads = threads;
this.threadRoots = new Node[threads];
}
public MCTSAI4(MCTSAI4<T> other) {
super(other);
this.threads = other.threads;
this.threadRoots = other.threadRoots;
}
@Override
public MCTSAI4<T> deepCopy() {
return new MCTSAI4<>(this);
}
@Override
public long getMove(T game) {
for (int i = 0; i < threads; i++) {
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
final int threadIndex = i;
tasks.add(() -> {
final Node localRoot = threadRoots[threadIndex];
// while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;
for (int i = 0; i < threads; i++) {
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
return move;
} catch (Exception _) {
lastIterations = 0;
final long legalMoves = game.getLegalMoves();
return randomSetBit(legalMoves);
}
}
}