mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 10:54:51 +00:00
update mcts
This commit is contained in:
@@ -1,41 +1,41 @@
|
|||||||
package org.toop;
|
package org.toop;
|
||||||
|
|
||||||
import org.toop.app.App;
|
import org.toop.app.App;
|
||||||
import org.toop.framework.gameFramework.model.player.AbstractPlayer;
|
|
||||||
import org.toop.framework.gameFramework.model.player.Player;
|
import org.toop.framework.gameFramework.model.player.Player;
|
||||||
import org.toop.game.games.reversi.BitboardReversi;
|
import org.toop.game.games.reversi.BitboardReversi;
|
||||||
import org.toop.game.games.tictactoe.BitboardTicTacToe;
|
|
||||||
import org.toop.game.players.ArtificialPlayer;
|
import org.toop.game.players.ArtificialPlayer;
|
||||||
import org.toop.game.players.ai.MCTSAI;
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
import org.toop.game.players.ai.MCTSAI2;
|
|
||||||
import org.toop.game.players.ai.MCTSAI3;
|
|
||||||
import org.toop.game.players.ai.RandomAI;
|
import org.toop.game.players.ai.RandomAI;
|
||||||
|
import org.toop.game.players.ai.mcts.MCTSAI1;
|
||||||
|
import org.toop.game.players.ai.mcts.MCTSAI2;
|
||||||
|
import org.toop.game.players.ai.mcts.MCTSAI3;
|
||||||
|
import org.toop.game.players.ai.mcts.MCTSAI4;
|
||||||
|
|
||||||
public final class Main {
|
public final class Main {
|
||||||
static void main(String[] args) {
|
static void main(String[] args) {
|
||||||
App.run(args);
|
// App.run(args);
|
||||||
// testMCTS(100);
|
testMCTS(25);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void testMCTS(int games) {
|
private static void testMCTS(int games) {
|
||||||
var versions = new ArtificialPlayer[5];
|
var versions = new ArtificialPlayer[5];
|
||||||
versions[0] = new ArtificialPlayer<>(new MCTSAI1<BitboardTicTacToe>(10), "MCTS V1 AI");
|
versions[0] = new ArtificialPlayer<>(new RandomAI<BitboardReversi>(), "Random AI");
|
||||||
versions[1] = new ArtificialPlayer<>(new MCTSAI2<BitboardTicTacToe>(10), "MCTS V2 AI");
|
versions[1] = new ArtificialPlayer<>(new MCTSAI1<BitboardReversi>(1000), "MCTS V1 AI");
|
||||||
versions[2] = new ArtificialPlayer<>(new MCTSAI3<BitboardTicTacToe>(10, 10), "MCTS V3 AI");
|
versions[2] = new ArtificialPlayer<>(new MCTSAI2<BitboardReversi>(1000), "MCTS V2 AI");
|
||||||
versions[3] = new ArtificialPlayer<>(new MCTSAI4<BitboardTicTacToe>(10, 10), "MCTS V4 AI");
|
versions[3] = new ArtificialPlayer<>(new MCTSAI3<BitboardReversi>(10, 10), "MCTS V3 AI");
|
||||||
versions[4] = new ArtificialPlayer<>(new MCTSAI5<BitboardTicTacToe>(10, 10), "MCTS V5 AI");
|
versions[4] = new ArtificialPlayer<>(new MCTSAI4<BitboardReversi>(10, 10), "MCTS V4 AI");
|
||||||
|
|
||||||
for (int i = 2; i < versions.length; i++) {
|
for (int i = 0; i < versions.length; i++) {
|
||||||
for (int j = i + 1; j < versions.length; j++) {
|
for (int j = i + 1; j < versions.length; j++) {
|
||||||
final int playerIndex1 = i % versions.length;
|
final int playerIndex1 = i % versions.length;
|
||||||
final int playerIndex2 = j % versions.length;
|
final int playerIndex2 = j % versions.length;
|
||||||
|
|
||||||
testAI(games, new Player[] { versions[playerIndex1], versions[playerIndex2]});
|
testAI(games, new ArtificialPlayer[] { versions[playerIndex1], versions[playerIndex2]});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void testAI(int games, Player<BitboardReversi>[] ais) {
|
private static void testAI(int games, ArtificialPlayer<BitboardReversi>[] ais) {
|
||||||
int wins = 0;
|
int wins = 0;
|
||||||
int ties = 0;
|
int ties = 0;
|
||||||
|
|
||||||
@@ -47,6 +47,11 @@ public final class Main {
|
|||||||
final long move = ais[currentAI].getMove(match);
|
final long move = ais[currentAI].getMove(match);
|
||||||
|
|
||||||
match.play(move);
|
match.play(move);
|
||||||
|
|
||||||
|
if (ais[currentAI].getAi() instanceof MCTSAI<?> mcts) {
|
||||||
|
final int lastIterations = mcts.getLastIterations();
|
||||||
|
System.out.printf("iterations %s: %d\n", ais[currentAI].getName(), lastIterations);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (match.getWinner() < 0) {
|
if (match.getWinner() < 0) {
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
public long unexpandedMoves;
|
public long unexpandedMoves;
|
||||||
|
|
||||||
public Node parent;
|
public Node parent;
|
||||||
|
|
||||||
public Node[] children;
|
public Node[] children;
|
||||||
public int expanded;
|
|
||||||
|
|
||||||
public float value;
|
public float value;
|
||||||
public int visits;
|
public int visits;
|
||||||
|
|
||||||
|
public float heuristic;
|
||||||
|
|
||||||
|
public float solved;
|
||||||
|
|
||||||
public Node(TurnBasedGame state, Node parent, long move) {
|
public Node(TurnBasedGame state, Node parent, long move) {
|
||||||
final long legalMoves = state.getLegalMoves();
|
final long legalMoves = state.getLegalMoves();
|
||||||
|
|
||||||
@@ -29,23 +31,26 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
this.unexpandedMoves = legalMoves;
|
this.unexpandedMoves = legalMoves;
|
||||||
|
|
||||||
this.parent = parent;
|
this.parent = parent;
|
||||||
|
|
||||||
this.children = new Node[Long.bitCount(legalMoves)];
|
this.children = new Node[Long.bitCount(legalMoves)];
|
||||||
this.expanded = 0;
|
|
||||||
|
|
||||||
this.value = 0.0f;
|
this.value = 0.0f;
|
||||||
this.visits = 0;
|
this.visits = 0;
|
||||||
|
|
||||||
this.solved = false;
|
this.heuristic = state.rateMove(move);
|
||||||
this.solvedValue = 0.0f;
|
|
||||||
|
this.solved = Float.NaN;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node(TurnBasedGame state) {
|
public Node(TurnBasedGame state) {
|
||||||
this(state, null, 0L);
|
this(state, null, 0L);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int getExpanded() {
|
||||||
|
return children.length - Long.bitCount(unexpandedMoves);
|
||||||
|
}
|
||||||
|
|
||||||
public boolean isFullyExpanded() {
|
public boolean isFullyExpanded() {
|
||||||
return expanded == children.length;
|
return unexpandedMoves == 0L;
|
||||||
}
|
}
|
||||||
|
|
||||||
public float calculateUCT(int parentVisits) {
|
public float calculateUCT(int parentVisits) {
|
||||||
@@ -54,12 +59,15 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
}
|
}
|
||||||
|
|
||||||
final float exploitation = value / visits;
|
final float exploitation = value / visits;
|
||||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||||
|
final float bias = heuristic * 10.0f / (visits + 1);
|
||||||
|
|
||||||
return exploitation + exploration;
|
return exploitation + exploration + bias;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node bestUCTChild() {
|
public Node bestUCTChild() {
|
||||||
|
final int expanded = getExpanded();
|
||||||
|
|
||||||
Node highestUCTChild = null;
|
Node highestUCTChild = null;
|
||||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
float highestUCT = Float.NEGATIVE_INFINITY;
|
||||||
|
|
||||||
@@ -76,117 +84,34 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final Random random = new Random();
|
protected static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
||||||
|
|
||||||
private final int milliseconds;
|
protected final int milliseconds;
|
||||||
|
|
||||||
private Node root;
|
protected int lastIterations;
|
||||||
|
|
||||||
public MCTSAI2(int milliseconds) {
|
public MCTSAI(int milliseconds) {
|
||||||
this.milliseconds = milliseconds;
|
this.milliseconds = milliseconds;
|
||||||
|
|
||||||
this.root = null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public MCTSAI2(MCTSAI2 other) {
|
public MCTSAI(MCTSAI other) {
|
||||||
this.random = other.random;
|
|
||||||
this.milliseconds = other.milliseconds;
|
this.milliseconds = other.milliseconds;
|
||||||
|
|
||||||
this.root = other.root;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public int getLastIterations() {
|
||||||
public MCTSAI2 deepCopy() {
|
return lastIterations;
|
||||||
return new MCTSAI2(this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
protected Node selection(Node root) {
|
||||||
public long getMove(TurnBasedGame game) {
|
// while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||||
root = findOrResetRoot(root, game);
|
while (root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||||
|
|
||||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
|
||||||
|
|
||||||
while (System.nanoTime() < endTime) {
|
|
||||||
Node leaf = selection(root);
|
|
||||||
leaf = expansion(leaf);
|
|
||||||
final float value = simulation(leaf);
|
|
||||||
backPropagation(leaf, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
final Node mostVisitedChild = mostVisitedChild(root);
|
|
||||||
final long move = mostVisitedChild.move;
|
|
||||||
|
|
||||||
root = findChildByMove(root, move);
|
|
||||||
|
|
||||||
return move;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node mostVisitedChild(Node root) {
|
|
||||||
Node mostVisitedChild = null;
|
|
||||||
int mostVisited = -1;
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].visits > mostVisited) {
|
|
||||||
mostVisitedChild = root.children[i];
|
|
||||||
mostVisited = root.children[i].visits;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mostVisitedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node findOrResetRoot(Node root, T game) {
|
|
||||||
if (root == null) {
|
|
||||||
return new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
|
||||||
root.children[i].parent = null;
|
|
||||||
return root.children[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node findChildByMove(Node root, long move) {
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].move == move) {
|
|
||||||
root.children[i].parent = null;
|
|
||||||
return root.children[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean areStatesEqual(long[] state1, long[] state2) {
|
|
||||||
if (state1.length != state2.length) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < state1.length; i++) {
|
|
||||||
if (state1[i] != state2[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
private Node selection(Node root) {
|
|
||||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
|
||||||
root = root.bestUCTChild();
|
root = root.bestUCTChild();
|
||||||
}
|
}
|
||||||
|
|
||||||
return root;
|
return root;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node expansion(Node leaf) {
|
protected Node expansion(Node leaf) {
|
||||||
if (leaf.unexpandedMoves == 0L) {
|
if (leaf.unexpandedMoves == 0L) {
|
||||||
return leaf;
|
return leaf;
|
||||||
}
|
}
|
||||||
@@ -198,15 +123,13 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
|
|
||||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
||||||
|
|
||||||
leaf.children[leaf.expanded] = expandedChild;
|
leaf.children[leaf.getExpanded()] = expandedChild;
|
||||||
leaf.expanded++;
|
|
||||||
|
|
||||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
leaf.unexpandedMoves &= ~unexpandedMove;
|
||||||
|
|
||||||
return expandedChild;
|
return expandedChild;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float simulation(Node leaf) {
|
protected float simulation(Node leaf) {
|
||||||
final TurnBasedGame copiedState = leaf.state.deepCopy();
|
final TurnBasedGame copiedState = leaf.state.deepCopy();
|
||||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
||||||
|
|
||||||
@@ -228,12 +151,12 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
return 0.0f;
|
return 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void backPropagation(Node leaf, float value) {
|
protected void backPropagation(Node leaf, float value) {
|
||||||
while (leaf != null) {
|
while (leaf != null) {
|
||||||
leaf.value += value;
|
leaf.value += value;
|
||||||
leaf.visits++;
|
leaf.visits++;
|
||||||
|
|
||||||
if (!leaf.solved) {
|
if (Float.isNaN(leaf.solved)) {
|
||||||
updateSolvedStatus(leaf);
|
updateSolvedStatus(leaf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,14 +165,91 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected Node mostVisitedChild(Node root) {
|
||||||
|
final int expanded = root.getExpanded();
|
||||||
|
|
||||||
|
Node mostVisitedChild = null;
|
||||||
|
int mostVisited = -1;
|
||||||
|
|
||||||
|
for (int i = 0; i < expanded; i++) {
|
||||||
|
if (root.children[i].visits > mostVisited) {
|
||||||
|
mostVisitedChild = root.children[i];
|
||||||
|
mostVisited = root.children[i].visits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mostVisitedChild;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Node findOrResetRoot(Node root, T game) {
|
||||||
|
if (root == null) {
|
||||||
|
return new Node(game.deepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
final int expanded = root.getExpanded();
|
||||||
|
|
||||||
|
for (int i = 0; i < expanded; i++) {
|
||||||
|
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
||||||
|
root.children[i].parent = null;
|
||||||
|
return root.children[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Node(game.deepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Node findChildByMove(Node root, long move) {
|
||||||
|
final int expanded = root.getExpanded();
|
||||||
|
|
||||||
|
for (int i = 0; i < expanded; i++) {
|
||||||
|
if (root.children[i].move == move) {
|
||||||
|
root.children[i].parent = null;
|
||||||
|
return root.children[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected boolean areStatesEqual(long[] state1, long[] state2) {
|
||||||
|
if (state1.length != state2.length) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < state1.length; i++) {
|
||||||
|
if (state1[i] != state2[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected long randomSetBit(long value) {
|
||||||
|
if (0L == value) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
final int bitCount = Long.bitCount(value);
|
||||||
|
final int randomBitCount = random.get().nextInt(bitCount);
|
||||||
|
|
||||||
|
for (int i = 0; i < randomBitCount; i++) {
|
||||||
|
value &= value - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return value & -value;
|
||||||
|
}
|
||||||
|
|
||||||
private void updateSolvedStatus(Node node) {
|
private void updateSolvedStatus(Node node) {
|
||||||
if (node.state.isTerminal()) {
|
if (node.state.isTerminal()) {
|
||||||
node.solved = true;
|
|
||||||
|
|
||||||
final int winner = node.state.getWinner();
|
final int winner = node.state.getWinner();
|
||||||
final int mover = 1 - node.state.getCurrentTurn();
|
final int mover = 1 - node.state.getCurrentTurn();
|
||||||
|
|
||||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
node.solved = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -260,13 +260,13 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
boolean foundDrawMove = false;
|
boolean foundDrawMove = false;
|
||||||
|
|
||||||
for (final Node child : node.children) {
|
for (final Node child : node.children) {
|
||||||
if (child.solved) {
|
if (!Float.isNaN(child.solved)) {
|
||||||
if (child.solvedValue == -1.0f) {
|
if (child.solved == -1.0f) {
|
||||||
foundWinningMove = true;
|
foundWinningMove = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (child.solvedValue == 0.0f) {
|
if (child.solved == 0.0f) {
|
||||||
foundDrawMove = true;
|
foundDrawMove = true;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -275,27 +275,10 @@ public class MCTSAI2 extends AbstractAI {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (foundWinningMove) {
|
if (foundWinningMove) {
|
||||||
node.solved = true;
|
node.solved = 1.0f;
|
||||||
node.solvedValue = 1.0f;
|
|
||||||
} else if (allChildrenSolved) {
|
} else if (allChildrenSolved) {
|
||||||
node.solved = true;
|
node.solved = foundDrawMove? 0.0f : -1.0f;
|
||||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private long randomSetBit(long value) {
|
|
||||||
if (0L == value) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
final int bitCount = Long.bitCount(value);
|
|
||||||
final int randomBitCount = random.nextInt(bitCount);
|
|
||||||
|
|
||||||
for (int i = 0; i < randomBitCount; i++) {
|
|
||||||
value &= value - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return value & -value;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -1,250 +0,0 @@
|
|||||||
package org.toop.game.players.ai;
|
|
||||||
|
|
||||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
|
||||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
public class MCTSAI1<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
|
||||||
private static class Node {
|
|
||||||
public TurnBasedGame<?> state;
|
|
||||||
|
|
||||||
public long move;
|
|
||||||
public long unexpandedMoves;
|
|
||||||
|
|
||||||
public Node parent;
|
|
||||||
|
|
||||||
public Node[] children;
|
|
||||||
public int expanded;
|
|
||||||
|
|
||||||
public float value;
|
|
||||||
public int visits;
|
|
||||||
|
|
||||||
public boolean solved;
|
|
||||||
public float solvedValue;
|
|
||||||
|
|
||||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
|
||||||
final long legalMoves = state.getLegalMoves();
|
|
||||||
|
|
||||||
this.state = state;
|
|
||||||
|
|
||||||
this.move = move;
|
|
||||||
this.unexpandedMoves = legalMoves;
|
|
||||||
|
|
||||||
this.parent = parent;
|
|
||||||
|
|
||||||
this.children = new Node[Long.bitCount(legalMoves)];
|
|
||||||
this.expanded = 0;
|
|
||||||
|
|
||||||
this.value = 0.0f;
|
|
||||||
this.visits = 0;
|
|
||||||
|
|
||||||
this.solved = false;
|
|
||||||
this.solvedValue = 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node(TurnBasedGame<?> state) {
|
|
||||||
this(state, null, 0L);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isFullyExpanded() {
|
|
||||||
return expanded == children.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float calculateUCT(int parentVisits) {
|
|
||||||
if (visits == 0) {
|
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
final float exploitation = value / visits;
|
|
||||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
|
||||||
|
|
||||||
return exploitation + exploration;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node bestUCTChild() {
|
|
||||||
Node highestUCTChild = null;
|
|
||||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
|
||||||
|
|
||||||
for (int i = 0; i < expanded; i++) {
|
|
||||||
final float childUCT = children[i].calculateUCT(visits);
|
|
||||||
|
|
||||||
if (childUCT > highestUCT) {
|
|
||||||
highestUCTChild = children[i];
|
|
||||||
highestUCT = childUCT;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return highestUCTChild;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final Random random = new Random();
|
|
||||||
|
|
||||||
private final int milliseconds;
|
|
||||||
|
|
||||||
public MCTSAI1(int milliseconds) {
|
|
||||||
this.milliseconds = milliseconds;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MCTSAI1(MCTSAI1<T> other) {
|
|
||||||
this.milliseconds = other.milliseconds;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MCTSAI1<T> deepCopy() {
|
|
||||||
return new MCTSAI1<>(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getMove(T game) {
|
|
||||||
final Node root = new Node(game, null, 0L);
|
|
||||||
|
|
||||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
|
||||||
|
|
||||||
while (System.nanoTime() < endTime) {
|
|
||||||
Node leaf = selection(root);
|
|
||||||
leaf = expansion(leaf);
|
|
||||||
final float value = simulation(leaf);
|
|
||||||
backPropagation(leaf, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
final Node mostVisitedChild = mostVisitedChild(root);
|
|
||||||
return mostVisitedChild.move;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node mostVisitedChild(Node root) {
|
|
||||||
Node mostVisitedChild = null;
|
|
||||||
int mostVisited = -1;
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].visits > mostVisited) {
|
|
||||||
mostVisitedChild = root.children[i];
|
|
||||||
mostVisited = root.children[i].visits;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mostVisitedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node selection(Node root) {
|
|
||||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
|
||||||
root = root.bestUCTChild();
|
|
||||||
}
|
|
||||||
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node expansion(Node leaf) {
|
|
||||||
if (leaf.unexpandedMoves == 0L) {
|
|
||||||
return leaf;
|
|
||||||
}
|
|
||||||
|
|
||||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
|
||||||
|
|
||||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
|
||||||
copiedState.play(unexpandedMove);
|
|
||||||
|
|
||||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
|
||||||
|
|
||||||
leaf.children[leaf.expanded] = expandedChild;
|
|
||||||
leaf.expanded++;
|
|
||||||
|
|
||||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
|
||||||
|
|
||||||
return expandedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private float simulation(Node leaf) {
|
|
||||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
|
||||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
|
||||||
|
|
||||||
while (!copiedState.isTerminal()) {
|
|
||||||
final long legalMoves = copiedState.getLegalMoves();
|
|
||||||
final long randomMove = randomSetBit(legalMoves);
|
|
||||||
|
|
||||||
copiedState.play(randomMove);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() == playerIndex) {
|
|
||||||
return 1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() >= 0) {
|
|
||||||
return -1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void backPropagation(Node leaf, float value) {
|
|
||||||
while (leaf != null) {
|
|
||||||
leaf.value += value;
|
|
||||||
leaf.visits++;
|
|
||||||
|
|
||||||
if (!leaf.solved) {
|
|
||||||
updateSolvedStatus(leaf);
|
|
||||||
}
|
|
||||||
|
|
||||||
value = -value;
|
|
||||||
leaf = leaf.parent;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateSolvedStatus(Node node) {
|
|
||||||
if (node.state.isTerminal()) {
|
|
||||||
node.solved = true;
|
|
||||||
|
|
||||||
final int winner = node.state.getWinner();
|
|
||||||
final int mover = 1 - node.state.getCurrentTurn();
|
|
||||||
|
|
||||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node.isFullyExpanded()) {
|
|
||||||
boolean allChildrenSolved = true;
|
|
||||||
boolean foundWinningMove = false;
|
|
||||||
boolean foundDrawMove = false;
|
|
||||||
|
|
||||||
for (final Node child : node.children) {
|
|
||||||
if (child.solved) {
|
|
||||||
if (child.solvedValue == -1.0f) {
|
|
||||||
foundWinningMove = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (child.solvedValue == 0.0f) {
|
|
||||||
foundDrawMove = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
allChildrenSolved = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (foundWinningMove) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = 1.0f;
|
|
||||||
} else if (allChildrenSolved) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private long randomSetBit(long value) {
|
|
||||||
if (0L == value) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
final int bitCount = Long.bitCount(value);
|
|
||||||
final int randomBitCount = random.nextInt(bitCount);
|
|
||||||
|
|
||||||
for (int i = 0; i < randomBitCount; i++) {
|
|
||||||
value &= value - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return value & -value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,354 +0,0 @@
|
|||||||
package org.toop.game.players.ai;
|
|
||||||
|
|
||||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
|
||||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Random;
|
|
||||||
import java.util.concurrent.Callable;
|
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
import java.util.concurrent.Future;
|
|
||||||
|
|
||||||
public class MCTSAI3 extends AbstractAI {
|
|
||||||
private static class Node {
|
|
||||||
public TurnBasedGame state;
|
|
||||||
|
|
||||||
public long move;
|
|
||||||
public long unexpandedMoves;
|
|
||||||
|
|
||||||
public Node parent;
|
|
||||||
|
|
||||||
public Node[] children;
|
|
||||||
public int expanded;
|
|
||||||
|
|
||||||
public float value;
|
|
||||||
public int visits;
|
|
||||||
|
|
||||||
public boolean solved;
|
|
||||||
public float solvedValue;
|
|
||||||
|
|
||||||
public Node(TurnBasedGame state, Node parent, long move) {
|
|
||||||
final long legalMoves = state.getLegalMoves();
|
|
||||||
|
|
||||||
this.state = state;
|
|
||||||
|
|
||||||
this.move = move;
|
|
||||||
this.unexpandedMoves = legalMoves;
|
|
||||||
|
|
||||||
this.parent = parent;
|
|
||||||
|
|
||||||
this.children = new Node[Long.bitCount(legalMoves)];
|
|
||||||
this.expanded = 0;
|
|
||||||
|
|
||||||
this.value = 0.0f;
|
|
||||||
this.visits = 0;
|
|
||||||
|
|
||||||
this.solved = false;
|
|
||||||
this.solvedValue = 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node(TurnBasedGame state) {
|
|
||||||
this(state, null, 0L);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isFullyExpanded() {
|
|
||||||
return expanded == children.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float calculateUCT(int parentVisits) {
|
|
||||||
if (visits == 0) {
|
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
final float exploitation = value / visits;
|
|
||||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
|
||||||
|
|
||||||
return exploitation + exploration;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node bestUCTChild() {
|
|
||||||
Node highestUCTChild = null;
|
|
||||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
|
||||||
|
|
||||||
for (int i = 0; i < expanded; i++) {
|
|
||||||
final float childUCT = children[i].calculateUCT(visits);
|
|
||||||
|
|
||||||
if (childUCT > highestUCT) {
|
|
||||||
highestUCTChild = children[i];
|
|
||||||
highestUCT = childUCT;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return highestUCTChild;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
|
||||||
|
|
||||||
private final int milliseconds;
|
|
||||||
private final int threads;
|
|
||||||
|
|
||||||
public MCTSAI3(int milliseconds, int threads) {
|
|
||||||
this.milliseconds = milliseconds;
|
|
||||||
this.threads = threads;
|
|
||||||
}
|
|
||||||
|
|
||||||
public MCTSAI3(MCTSAI3 other) {
|
|
||||||
this.random = other.random;
|
|
||||||
|
|
||||||
this.root = other.root;
|
|
||||||
this.milliseconds = other.milliseconds;
|
|
||||||
this.threads = other.threads;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MCTSAI3 deepCopy() {
|
|
||||||
return new MCTSAI3(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getMove(T game) {
|
|
||||||
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
|
||||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
|
||||||
|
|
||||||
final List<Callable<Node>> tasks = new ArrayList<>();
|
|
||||||
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
tasks.add(() -> {
|
|
||||||
final Node localRoot = new Node(game.deepCopy());
|
|
||||||
|
|
||||||
while (System.nanoTime() < endTime) {
|
|
||||||
Node leaf = selection(localRoot);
|
|
||||||
leaf = expansion(leaf);
|
|
||||||
final float value = simulation(leaf);
|
|
||||||
backPropagation(leaf, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
return localRoot;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
final List<Future<Node>> results = pool.invokeAll(tasks);
|
|
||||||
|
|
||||||
pool.shutdown();
|
|
||||||
|
|
||||||
final Node root = new Node(game.deepCopy());
|
|
||||||
|
|
||||||
for (int i = 0; i < root.children.length; i++) {
|
|
||||||
expansion(root);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (final Future<Node> result : results) {
|
|
||||||
final Node localRoot = result.get();
|
|
||||||
|
|
||||||
for (final Node localChild : localRoot.children) {
|
|
||||||
for (int i = 0; i < root.children.length; i++) {
|
|
||||||
if (localChild.move == root.children[i].move) {
|
|
||||||
root.children[i].visits += localChild.visits;
|
|
||||||
root.visits += localChild.visits;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
final Node mostVisitedChild = mostVisitedChild(root);
|
|
||||||
return mostVisitedChild.move;
|
|
||||||
} catch (Exception _) {
|
|
||||||
final long legalMoves = game.getLegalMoves();
|
|
||||||
return randomSetBit(legalMoves);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node mostVisitedChild(Node root) {
|
|
||||||
Node mostVisitedChild = null;
|
|
||||||
int mostVisited = -1;
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].visits > mostVisited) {
|
|
||||||
mostVisitedChild = root.children[i];
|
|
||||||
mostVisited = root.children[i].visits;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mostVisitedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void detectRoot(TurnBasedGame game) {
|
|
||||||
if (root == null) {
|
|
||||||
root = new Node(game.deepCopy());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
final long[] currentBoards = game.getBoard();
|
|
||||||
final long[] rootBoards = root.state.getBoard();
|
|
||||||
|
|
||||||
boolean detected = true;
|
|
||||||
|
|
||||||
for (int i = 0; i < rootBoards.length; i++) {
|
|
||||||
if (rootBoards[i] != currentBoards[i]) {
|
|
||||||
detected = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (detected) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
final Node child = root.children[i];
|
|
||||||
|
|
||||||
final long[] childBoards = child.state.getBoard();
|
|
||||||
|
|
||||||
detected = true;
|
|
||||||
|
|
||||||
for (int j = 0; j < childBoards.length; j++) {
|
|
||||||
if (childBoards[j] != currentBoards[j]) {
|
|
||||||
detected = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (detected) {
|
|
||||||
root = child;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
root = new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
private void newRoot(long move) {
|
|
||||||
for (final Node child : root.children) {
|
|
||||||
if (child.move == move) {
|
|
||||||
root = child;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node selection(Node root) {
|
|
||||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
|
||||||
root = root.bestUCTChild();
|
|
||||||
}
|
|
||||||
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node expansion(Node leaf) {
|
|
||||||
if (leaf.unexpandedMoves == 0L) {
|
|
||||||
return leaf;
|
|
||||||
}
|
|
||||||
|
|
||||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
|
||||||
|
|
||||||
final TurnBasedGame copiedState = leaf.state.deepCopy();
|
|
||||||
copiedState.play(unexpandedMove);
|
|
||||||
|
|
||||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
|
||||||
|
|
||||||
leaf.children[leaf.expanded] = expandedChild;
|
|
||||||
leaf.expanded++;
|
|
||||||
|
|
||||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
|
||||||
|
|
||||||
return expandedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private float simulation(Node leaf) {
|
|
||||||
final TurnBasedGame copiedState = leaf.state.deepCopy();
|
|
||||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
|
||||||
|
|
||||||
while (!copiedState.isTerminal()) {
|
|
||||||
final long legalMoves = copiedState.getLegalMoves();
|
|
||||||
final long randomMove = randomSetBit(legalMoves);
|
|
||||||
|
|
||||||
copiedState.play(randomMove);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() == playerIndex) {
|
|
||||||
return 1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() >= 0) {
|
|
||||||
return -1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void backPropagation(Node leaf, float value) {
|
|
||||||
while (leaf != null) {
|
|
||||||
leaf.value += value;
|
|
||||||
leaf.visits++;
|
|
||||||
|
|
||||||
if (!leaf.solved) {
|
|
||||||
updateSolvedStatus(leaf);
|
|
||||||
}
|
|
||||||
|
|
||||||
value = -value;
|
|
||||||
leaf = leaf.parent;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateSolvedStatus(Node node) {
|
|
||||||
if (node.state.isTerminal()) {
|
|
||||||
node.solved = true;
|
|
||||||
|
|
||||||
final int winner = node.state.getWinner();
|
|
||||||
final int mover = 1 - node.state.getCurrentTurn();
|
|
||||||
|
|
||||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node.isFullyExpanded()) {
|
|
||||||
boolean allChildrenSolved = true;
|
|
||||||
boolean foundWinningMove = false;
|
|
||||||
boolean foundDrawMove = false;
|
|
||||||
|
|
||||||
for (final Node child : node.children) {
|
|
||||||
if (child.solved) {
|
|
||||||
if (child.solvedValue == -1.0f) {
|
|
||||||
foundWinningMove = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (child.solvedValue == 0.0f) {
|
|
||||||
foundDrawMove = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
allChildrenSolved = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (foundWinningMove) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = 1.0f;
|
|
||||||
} else if (allChildrenSolved) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private long randomSetBit(long value) {
|
|
||||||
if (0L == value) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
final int bitCount = Long.bitCount(value);
|
|
||||||
final int randomBitCount = random.get().nextInt(bitCount);
|
|
||||||
|
|
||||||
for (int i = 0; i < randomBitCount; i++) {
|
|
||||||
value &= value - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return value & -value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,359 +0,0 @@
|
|||||||
package org.toop.game.players.ai;
|
|
||||||
|
|
||||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
|
||||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Random;
|
|
||||||
import java.util.concurrent.Callable;
|
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
import java.util.concurrent.Future;
|
|
||||||
|
|
||||||
public class MCTSAI4<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
|
||||||
private static class Node {
|
|
||||||
public TurnBasedGame<?> state;
|
|
||||||
|
|
||||||
public long move;
|
|
||||||
public long unexpandedMoves;
|
|
||||||
|
|
||||||
public Node parent;
|
|
||||||
|
|
||||||
public Node[] children;
|
|
||||||
public int expanded;
|
|
||||||
|
|
||||||
public float value;
|
|
||||||
public int visits;
|
|
||||||
|
|
||||||
public boolean solved;
|
|
||||||
public float solvedValue;
|
|
||||||
|
|
||||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
|
||||||
final long legalMoves = state.getLegalMoves();
|
|
||||||
|
|
||||||
this.state = state;
|
|
||||||
|
|
||||||
this.move = move;
|
|
||||||
this.unexpandedMoves = legalMoves;
|
|
||||||
|
|
||||||
this.parent = parent;
|
|
||||||
|
|
||||||
this.children = new Node[Long.bitCount(legalMoves)];
|
|
||||||
this.expanded = 0;
|
|
||||||
|
|
||||||
this.value = 0.0f;
|
|
||||||
this.visits = 0;
|
|
||||||
|
|
||||||
this.solved = false;
|
|
||||||
this.solvedValue = 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node(TurnBasedGame<?> state) {
|
|
||||||
this(state, null, 0L);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isFullyExpanded() {
|
|
||||||
return expanded == children.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float calculateUCT(int parentVisits) {
|
|
||||||
if (visits == 0) {
|
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
final float exploitation = value / visits;
|
|
||||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
|
||||||
|
|
||||||
return exploitation + exploration;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node bestUCTChild() {
|
|
||||||
Node highestUCTChild = null;
|
|
||||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
|
||||||
|
|
||||||
for (int i = 0; i < expanded; i++) {
|
|
||||||
final float childUCT = children[i].calculateUCT(visits);
|
|
||||||
|
|
||||||
if (childUCT > highestUCT) {
|
|
||||||
highestUCTChild = children[i];
|
|
||||||
highestUCT = childUCT;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return highestUCTChild;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
|
||||||
|
|
||||||
private final int milliseconds;
|
|
||||||
private final int threads;
|
|
||||||
|
|
||||||
private final Node[] threadRoots;
|
|
||||||
|
|
||||||
public MCTSAI4(int milliseconds, int threads) {
|
|
||||||
this.milliseconds = milliseconds;
|
|
||||||
this.threads = threads;
|
|
||||||
|
|
||||||
this.threadRoots = new Node[threads];
|
|
||||||
}
|
|
||||||
|
|
||||||
public MCTSAI4(MCTSAI4<T> other) {
|
|
||||||
this.milliseconds = other.milliseconds;
|
|
||||||
this.threads = other.threads;
|
|
||||||
|
|
||||||
this.threadRoots = other.threadRoots;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MCTSAI4<T> deepCopy() {
|
|
||||||
return new MCTSAI4<>(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getMove(T game) {
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
|
|
||||||
}
|
|
||||||
|
|
||||||
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
|
||||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
|
||||||
|
|
||||||
final List<Callable<Node>> tasks = new ArrayList<>();
|
|
||||||
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
final int threadIndex = i;
|
|
||||||
|
|
||||||
tasks.add(() -> {
|
|
||||||
final Node localRoot = threadRoots[threadIndex];
|
|
||||||
|
|
||||||
while (System.nanoTime() < endTime) {
|
|
||||||
Node leaf = selection(localRoot);
|
|
||||||
leaf = expansion(leaf);
|
|
||||||
final float value = simulation(leaf);
|
|
||||||
backPropagation(leaf, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
return localRoot;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
final List<Future<Node>> results = pool.invokeAll(tasks);
|
|
||||||
|
|
||||||
pool.shutdown();
|
|
||||||
|
|
||||||
final Node root = new Node(game.deepCopy());
|
|
||||||
|
|
||||||
for (int i = 0; i < root.children.length; i++) {
|
|
||||||
expansion(root);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (final Future<Node> result : results) {
|
|
||||||
final Node localRoot = result.get();
|
|
||||||
|
|
||||||
for (final Node localChild : localRoot.children) {
|
|
||||||
for (int i = 0; i < root.children.length; i++) {
|
|
||||||
if (localChild.move == root.children[i].move) {
|
|
||||||
root.children[i].visits += localChild.visits;
|
|
||||||
root.visits += localChild.visits;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
final Node mostVisitedChild = mostVisitedChild(root);
|
|
||||||
final long move = mostVisitedChild.move;
|
|
||||||
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
threadRoots[i] = findChildByMove(threadRoots[i], move);
|
|
||||||
}
|
|
||||||
|
|
||||||
return move;
|
|
||||||
} catch (Exception _) {
|
|
||||||
final long legalMoves = game.getLegalMoves();
|
|
||||||
return randomSetBit(legalMoves);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node mostVisitedChild(Node root) {
|
|
||||||
Node mostVisitedChild = null;
|
|
||||||
int mostVisited = -1;
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].visits > mostVisited) {
|
|
||||||
mostVisitedChild = root.children[i];
|
|
||||||
mostVisited = root.children[i].visits;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mostVisitedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node findOrResetRoot(Node root, T game) {
|
|
||||||
if (root == null) {
|
|
||||||
return new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
|
||||||
root.children[i].parent = null;
|
|
||||||
return root.children[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node findChildByMove(Node root, long move) {
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].move == move) {
|
|
||||||
root.children[i].parent = null;
|
|
||||||
return root.children[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean areStatesEqual(long[] state1, long[] state2) {
|
|
||||||
if (state1.length != state2.length) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < state1.length; i++) {
|
|
||||||
if (state1[i] != state2[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node selection(Node root) {
|
|
||||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
|
||||||
root = root.bestUCTChild();
|
|
||||||
}
|
|
||||||
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node expansion(Node leaf) {
|
|
||||||
if (leaf.unexpandedMoves == 0L) {
|
|
||||||
return leaf;
|
|
||||||
}
|
|
||||||
|
|
||||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
|
||||||
|
|
||||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
|
||||||
copiedState.play(unexpandedMove);
|
|
||||||
|
|
||||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
|
||||||
|
|
||||||
leaf.children[leaf.expanded] = expandedChild;
|
|
||||||
leaf.expanded++;
|
|
||||||
|
|
||||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
|
||||||
|
|
||||||
return expandedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private float simulation(Node leaf) {
|
|
||||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
|
||||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
|
||||||
|
|
||||||
while (!copiedState.isTerminal()) {
|
|
||||||
final long legalMoves = copiedState.getLegalMoves();
|
|
||||||
final long randomMove = randomSetBit(legalMoves);
|
|
||||||
|
|
||||||
copiedState.play(randomMove);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() == playerIndex) {
|
|
||||||
return 1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() >= 0) {
|
|
||||||
return -1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void backPropagation(Node leaf, float value) {
|
|
||||||
while (leaf != null) {
|
|
||||||
leaf.value += value;
|
|
||||||
leaf.visits++;
|
|
||||||
|
|
||||||
if (!leaf.solved) {
|
|
||||||
updateSolvedStatus(leaf);
|
|
||||||
}
|
|
||||||
|
|
||||||
value = -value;
|
|
||||||
leaf = leaf.parent;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateSolvedStatus(Node node) {
|
|
||||||
if (node.state.isTerminal()) {
|
|
||||||
node.solved = true;
|
|
||||||
|
|
||||||
final int winner = node.state.getWinner();
|
|
||||||
final int mover = 1 - node.state.getCurrentTurn();
|
|
||||||
|
|
||||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node.isFullyExpanded()) {
|
|
||||||
boolean allChildrenSolved = true;
|
|
||||||
boolean foundWinningMove = false;
|
|
||||||
boolean foundDrawMove = false;
|
|
||||||
|
|
||||||
for (final Node child : node.children) {
|
|
||||||
if (child.solved) {
|
|
||||||
if (child.solvedValue == -1.0f) {
|
|
||||||
foundWinningMove = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (child.solvedValue == 0.0f) {
|
|
||||||
foundDrawMove = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
allChildrenSolved = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (foundWinningMove) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = 1.0f;
|
|
||||||
} else if (allChildrenSolved) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private long randomSetBit(long value) {
|
|
||||||
if (0L == value) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
final int bitCount = Long.bitCount(value);
|
|
||||||
final int randomBitCount = random.get().nextInt(bitCount);
|
|
||||||
|
|
||||||
for (int i = 0; i < randomBitCount; i++) {
|
|
||||||
value &= value - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return value & -value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,371 +0,0 @@
|
|||||||
package org.toop.game.players.ai;
|
|
||||||
|
|
||||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
|
||||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Random;
|
|
||||||
import java.util.concurrent.Callable;
|
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
import java.util.concurrent.Future;
|
|
||||||
|
|
||||||
public class MCTSAI5<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
|
||||||
private static class Node {
|
|
||||||
public TurnBasedGame<?> state;
|
|
||||||
|
|
||||||
public long move;
|
|
||||||
public long unexpandedMoves;
|
|
||||||
|
|
||||||
public Node parent;
|
|
||||||
|
|
||||||
public Node[] children;
|
|
||||||
public int expanded;
|
|
||||||
|
|
||||||
public float value;
|
|
||||||
public int visits;
|
|
||||||
|
|
||||||
public boolean solved;
|
|
||||||
public float solvedValue;
|
|
||||||
|
|
||||||
public float heuristic;
|
|
||||||
|
|
||||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
|
||||||
final long legalMoves = state.getLegalMoves();
|
|
||||||
|
|
||||||
this.state = state;
|
|
||||||
|
|
||||||
this.move = move;
|
|
||||||
this.unexpandedMoves = legalMoves;
|
|
||||||
|
|
||||||
this.parent = parent;
|
|
||||||
|
|
||||||
this.children = new Node[Long.bitCount(legalMoves)];
|
|
||||||
this.expanded = 0;
|
|
||||||
|
|
||||||
this.value = 0.0f;
|
|
||||||
this.visits = 0;
|
|
||||||
|
|
||||||
this.solved = false;
|
|
||||||
this.solvedValue = 0.0f;
|
|
||||||
|
|
||||||
this.heuristic = state.rateMove(move);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node(TurnBasedGame<?> state) {
|
|
||||||
this(state, null, 0L);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isFullyExpanded() {
|
|
||||||
return expanded == children.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float calculateUCT(int parentVisits) {
|
|
||||||
if (visits == 0) {
|
|
||||||
return Float.POSITIVE_INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
final float exploitation = value / visits;
|
|
||||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
|
||||||
final float bias = heuristic / visits;
|
|
||||||
|
|
||||||
return exploitation + exploration + bias;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Node bestUCTChild() {
|
|
||||||
Node highestUCTChild = null;
|
|
||||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
|
||||||
|
|
||||||
for (int i = 0; i < expanded; i++) {
|
|
||||||
final float childUCT = children[i].calculateUCT(visits);
|
|
||||||
|
|
||||||
if (childUCT > highestUCT) {
|
|
||||||
highestUCTChild = children[i];
|
|
||||||
highestUCT = childUCT;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return highestUCTChild;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
|
||||||
|
|
||||||
private final int milliseconds;
|
|
||||||
private final int threads;
|
|
||||||
|
|
||||||
private final Node[] threadRoots;
|
|
||||||
|
|
||||||
public MCTSAI5(int milliseconds, int threads) {
|
|
||||||
this.milliseconds = milliseconds;
|
|
||||||
this.threads = threads;
|
|
||||||
|
|
||||||
this.threadRoots = new Node[threads];
|
|
||||||
}
|
|
||||||
|
|
||||||
public MCTSAI5(MCTSAI5<T> other) {
|
|
||||||
this.milliseconds = other.milliseconds;
|
|
||||||
this.threads = other.threads;
|
|
||||||
|
|
||||||
this.threadRoots = other.threadRoots;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public MCTSAI5<T> deepCopy() {
|
|
||||||
return new MCTSAI5<>(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getMove(T game) {
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
|
|
||||||
}
|
|
||||||
|
|
||||||
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
|
||||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
|
||||||
|
|
||||||
final List<Callable<Node>> tasks = new ArrayList<>();
|
|
||||||
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
final int threadIndex = i;
|
|
||||||
|
|
||||||
tasks.add(() -> {
|
|
||||||
final Node localRoot = threadRoots[threadIndex];
|
|
||||||
|
|
||||||
while (System.nanoTime() < endTime) {
|
|
||||||
Node leaf = selection(localRoot);
|
|
||||||
leaf = expansion(leaf);
|
|
||||||
final float value = simulation(leaf);
|
|
||||||
backPropagation(leaf, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
return localRoot;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
final List<Future<Node>> results = pool.invokeAll(tasks);
|
|
||||||
|
|
||||||
pool.shutdown();
|
|
||||||
|
|
||||||
final Node root = new Node(game.deepCopy());
|
|
||||||
|
|
||||||
for (int i = 0; i < root.children.length; i++) {
|
|
||||||
expansion(root);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (final Future<Node> result : results) {
|
|
||||||
final Node localRoot = result.get();
|
|
||||||
|
|
||||||
for (final Node localChild : localRoot.children) {
|
|
||||||
for (int i = 0; i < root.children.length; i++) {
|
|
||||||
if (localChild.move == root.children[i].move) {
|
|
||||||
root.children[i].visits += localChild.visits;
|
|
||||||
root.visits += localChild.visits;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
final Node mostVisitedChild = mostVisitedChild(root);
|
|
||||||
final long move = mostVisitedChild.move;
|
|
||||||
|
|
||||||
for (int i = 0; i < threads; i++) {
|
|
||||||
threadRoots[i] = findChildByMove(threadRoots[i], move);
|
|
||||||
}
|
|
||||||
|
|
||||||
return move;
|
|
||||||
} catch (Exception _) {
|
|
||||||
final long legalMoves = game.getLegalMoves();
|
|
||||||
return randomSetBit(legalMoves);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node mostVisitedChild(Node root) {
|
|
||||||
Node mostVisitedChild = null;
|
|
||||||
int mostVisited = -1;
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].visits > mostVisited) {
|
|
||||||
mostVisitedChild = root.children[i];
|
|
||||||
mostVisited = root.children[i].visits;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mostVisitedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node findOrResetRoot(Node root, T game) {
|
|
||||||
if (root == null) {
|
|
||||||
return new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
|
||||||
root.children[i].parent = null;
|
|
||||||
return root.children[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Node(game.deepCopy());
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node findChildByMove(Node root, long move) {
|
|
||||||
for (int i = 0; i < root.expanded; i++) {
|
|
||||||
if (root.children[i].move == move) {
|
|
||||||
root.children[i].parent = null;
|
|
||||||
return root.children[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean areStatesEqual(long[] state1, long[] state2) {
|
|
||||||
if (state1.length != state2.length) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < state1.length; i++) {
|
|
||||||
if (state1[i] != state2[i]) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node selection(Node root) {
|
|
||||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
|
||||||
root = root.bestUCTChild();
|
|
||||||
}
|
|
||||||
|
|
||||||
return root;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Node expansion(Node leaf) {
|
|
||||||
if (leaf.unexpandedMoves == 0L) {
|
|
||||||
return leaf;
|
|
||||||
}
|
|
||||||
|
|
||||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
|
||||||
|
|
||||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
|
||||||
copiedState.play(unexpandedMove);
|
|
||||||
|
|
||||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
|
||||||
|
|
||||||
leaf.children[leaf.expanded] = expandedChild;
|
|
||||||
leaf.expanded++;
|
|
||||||
|
|
||||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
|
||||||
|
|
||||||
return expandedChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
private float simulation(Node leaf) {
|
|
||||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
|
||||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
|
||||||
|
|
||||||
while (!copiedState.isTerminal()) {
|
|
||||||
final long legalMoves = copiedState.getLegalMoves();
|
|
||||||
|
|
||||||
long move = 0L;
|
|
||||||
|
|
||||||
if (random.get().nextFloat() > 0.9f) {
|
|
||||||
move = copiedState.heuristicMove(legalMoves);
|
|
||||||
} else {
|
|
||||||
move = randomSetBit(legalMoves);
|
|
||||||
}
|
|
||||||
|
|
||||||
copiedState.play(move);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() == playerIndex) {
|
|
||||||
return 1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (copiedState.getWinner() >= 0) {
|
|
||||||
return -1.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0.0f;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void backPropagation(Node leaf, float value) {
|
|
||||||
while (leaf != null) {
|
|
||||||
leaf.value += value;
|
|
||||||
leaf.visits++;
|
|
||||||
|
|
||||||
if (!leaf.solved) {
|
|
||||||
updateSolvedStatus(leaf);
|
|
||||||
}
|
|
||||||
|
|
||||||
value = -value;
|
|
||||||
leaf = leaf.parent;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateSolvedStatus(Node node) {
|
|
||||||
if (node.state.isTerminal()) {
|
|
||||||
node.solved = true;
|
|
||||||
|
|
||||||
final int winner = node.state.getWinner();
|
|
||||||
final int mover = 1 - node.state.getCurrentTurn();
|
|
||||||
|
|
||||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node.isFullyExpanded()) {
|
|
||||||
boolean allChildrenSolved = true;
|
|
||||||
boolean foundWinningMove = false;
|
|
||||||
boolean foundDrawMove = false;
|
|
||||||
|
|
||||||
for (final Node child : node.children) {
|
|
||||||
if (child.solved) {
|
|
||||||
if (child.solvedValue == -1.0f) {
|
|
||||||
foundWinningMove = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (child.solvedValue == 0.0f) {
|
|
||||||
foundDrawMove = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
allChildrenSolved = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (foundWinningMove) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = 1.0f;
|
|
||||||
} else if (allChildrenSolved) {
|
|
||||||
node.solved = true;
|
|
||||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private long randomSetBit(long value) {
|
|
||||||
if (0L == value) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
final int bitCount = Long.bitCount(value);
|
|
||||||
final int randomBitCount = random.get().nextInt(bitCount);
|
|
||||||
|
|
||||||
for (int i = 0; i < randomBitCount; i++) {
|
|
||||||
value &= value - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return value & -value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package org.toop.game.players.ai.mcts;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
|
|
||||||
|
public class MCTSAI1<T extends TurnBasedGame<T>> extends MCTSAI<T> {
|
||||||
|
public MCTSAI1(int milliseconds) {
|
||||||
|
super(milliseconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI1(MCTSAI1<T> other) {
|
||||||
|
super(other);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI1<T> deepCopy() {
|
||||||
|
return new MCTSAI1<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
final Node root = new Node(game, null, 0L);
|
||||||
|
|
||||||
|
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||||
|
|
||||||
|
// while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
|
||||||
|
while (System.nanoTime() < endTime) {
|
||||||
|
Node leaf = selection(root);
|
||||||
|
leaf = expansion(leaf);
|
||||||
|
final float value = simulation(leaf);
|
||||||
|
backPropagation(leaf, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIterations = root.visits;
|
||||||
|
|
||||||
|
final Node mostVisitedChild = mostVisitedChild(root);
|
||||||
|
return mostVisitedChild.move;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package org.toop.game.players.ai.mcts;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
|
|
||||||
|
public class MCTSAI2<T extends TurnBasedGame<T>> extends MCTSAI<T> {
|
||||||
|
private Node root;
|
||||||
|
|
||||||
|
public MCTSAI2(int milliseconds) {
|
||||||
|
super(milliseconds);
|
||||||
|
|
||||||
|
this.root = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI2(MCTSAI2<T> other) {
|
||||||
|
super(other);
|
||||||
|
|
||||||
|
this.root = other.root;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI2<T> deepCopy() {
|
||||||
|
return new MCTSAI2<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
root = findOrResetRoot(root, game);
|
||||||
|
|
||||||
|
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||||
|
|
||||||
|
// while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
|
||||||
|
while (System.nanoTime() < endTime) {
|
||||||
|
Node leaf = selection(root);
|
||||||
|
leaf = expansion(leaf);
|
||||||
|
final float value = simulation(leaf);
|
||||||
|
backPropagation(leaf, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIterations = root.visits;
|
||||||
|
|
||||||
|
final Node mostVisitedChild = mostVisitedChild(root);
|
||||||
|
final long move = mostVisitedChild.move;
|
||||||
|
|
||||||
|
root = findChildByMove(root, move);
|
||||||
|
|
||||||
|
return move;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package org.toop.game.players.ai.mcts;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.Future;
|
||||||
|
|
||||||
|
public class MCTSAI3<T extends TurnBasedGame<T>> extends MCTSAI<T> {
|
||||||
|
private final int threads;
|
||||||
|
|
||||||
|
public MCTSAI3(int milliseconds, int threads) {
|
||||||
|
super(milliseconds);
|
||||||
|
|
||||||
|
this.threads = threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI3(MCTSAI3<T> other) {
|
||||||
|
super(other);
|
||||||
|
|
||||||
|
this.threads = other.threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI3<T> deepCopy() {
|
||||||
|
return new MCTSAI3<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
||||||
|
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||||
|
|
||||||
|
final List<Callable<Node>> tasks = new ArrayList<>();
|
||||||
|
|
||||||
|
for (int i = 0; i < threads; i++) {
|
||||||
|
tasks.add(() -> {
|
||||||
|
final Node localRoot = new Node(game.deepCopy());
|
||||||
|
|
||||||
|
while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
|
||||||
|
Node leaf = selection(localRoot);
|
||||||
|
leaf = expansion(leaf);
|
||||||
|
final float value = simulation(leaf);
|
||||||
|
backPropagation(leaf, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
return localRoot;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
final List<Future<Node>> results = pool.invokeAll(tasks);
|
||||||
|
|
||||||
|
pool.shutdown();
|
||||||
|
|
||||||
|
final Node root = new Node(game.deepCopy());
|
||||||
|
|
||||||
|
for (int i = 0; i < root.children.length; i++) {
|
||||||
|
expansion(root);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (final Future<Node> result : results) {
|
||||||
|
final Node localRoot = result.get();
|
||||||
|
|
||||||
|
for (final Node localChild : localRoot.children) {
|
||||||
|
for (int i = 0; i < root.children.length; i++) {
|
||||||
|
if (localChild.move == root.children[i].move) {
|
||||||
|
root.children[i].visits += localChild.visits;
|
||||||
|
root.visits += localChild.visits;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIterations = root.visits;
|
||||||
|
|
||||||
|
final Node mostVisitedChild = mostVisitedChild(root);
|
||||||
|
return mostVisitedChild.move;
|
||||||
|
} catch (Exception _) {
|
||||||
|
lastIterations = 0;
|
||||||
|
|
||||||
|
final long legalMoves = game.getLegalMoves();
|
||||||
|
return randomSetBit(legalMoves);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
107
game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java
Normal file
107
game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package org.toop.game.players.ai.mcts;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.Callable;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.Future;
|
||||||
|
|
||||||
|
public class MCTSAI4<T extends TurnBasedGame<T>> extends MCTSAI<T> {
|
||||||
|
private final int threads;
|
||||||
|
private final Node[] threadRoots;
|
||||||
|
|
||||||
|
public MCTSAI4(int milliseconds, int threads) {
|
||||||
|
super(milliseconds);
|
||||||
|
|
||||||
|
this.threads = threads;
|
||||||
|
this.threadRoots = new Node[threads];
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI4(MCTSAI4<T> other) {
|
||||||
|
super(other);
|
||||||
|
|
||||||
|
this.threads = other.threads;
|
||||||
|
this.threadRoots = other.threadRoots;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI4<T> deepCopy() {
|
||||||
|
return new MCTSAI4<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
for (int i = 0; i < threads; i++) {
|
||||||
|
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
|
||||||
|
}
|
||||||
|
|
||||||
|
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
||||||
|
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||||
|
|
||||||
|
final List<Callable<Node>> tasks = new ArrayList<>();
|
||||||
|
|
||||||
|
for (int i = 0; i < threads; i++) {
|
||||||
|
final int threadIndex = i;
|
||||||
|
|
||||||
|
tasks.add(() -> {
|
||||||
|
final Node localRoot = threadRoots[threadIndex];
|
||||||
|
|
||||||
|
// while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
|
||||||
|
while (System.nanoTime() < endTime) {
|
||||||
|
Node leaf = selection(localRoot);
|
||||||
|
leaf = expansion(leaf);
|
||||||
|
final float value = simulation(leaf);
|
||||||
|
backPropagation(leaf, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
return localRoot;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
final List<Future<Node>> results = pool.invokeAll(tasks);
|
||||||
|
|
||||||
|
pool.shutdown();
|
||||||
|
|
||||||
|
final Node root = new Node(game.deepCopy());
|
||||||
|
|
||||||
|
for (int i = 0; i < root.children.length; i++) {
|
||||||
|
expansion(root);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (final Future<Node> result : results) {
|
||||||
|
final Node localRoot = result.get();
|
||||||
|
|
||||||
|
for (final Node localChild : localRoot.children) {
|
||||||
|
for (int i = 0; i < root.children.length; i++) {
|
||||||
|
if (localChild.move == root.children[i].move) {
|
||||||
|
root.children[i].visits += localChild.visits;
|
||||||
|
root.visits += localChild.visits;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIterations = root.visits;
|
||||||
|
|
||||||
|
final Node mostVisitedChild = mostVisitedChild(root);
|
||||||
|
final long move = mostVisitedChild.move;
|
||||||
|
|
||||||
|
for (int i = 0; i < threads; i++) {
|
||||||
|
threadRoots[i] = findChildByMove(threadRoots[i], move);
|
||||||
|
}
|
||||||
|
|
||||||
|
return move;
|
||||||
|
} catch (Exception _) {
|
||||||
|
lastIterations = 0;
|
||||||
|
|
||||||
|
final long legalMoves = game.getLegalMoves();
|
||||||
|
return randomSetBit(legalMoves);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user