mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 19:04:49 +00:00
mcts v1, v2, v3, v4 done. v5 wip
This commit is contained in:
@@ -6,7 +6,6 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.Player;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
// There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy
|
||||
public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBasedGame<T> {
|
||||
@@ -19,7 +18,7 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
||||
|
||||
// long is 64 bits. Every game has a limit of 64 cells maximum.
|
||||
private final long[] playerBitboard;
|
||||
private int currentTurn = 0;
|
||||
protected int currentTurn = 0;
|
||||
|
||||
public BitboardGame(int columnSize, int rowSize, int playerCount, Player<T>[] players) {
|
||||
this.columnSize = columnSize;
|
||||
@@ -82,10 +81,6 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
||||
return (currentTurn + 1) % playerBitboard.length;
|
||||
}
|
||||
|
||||
public Player<T> getCurrentPlayer(){
|
||||
return players[getCurrentPlayerIndex()];
|
||||
}
|
||||
|
||||
@Override
|
||||
public PlayResult getState() {
|
||||
return state;
|
||||
|
||||
@@ -6,9 +6,6 @@ import org.toop.framework.gameFramework.model.player.Player;
|
||||
import org.toop.game.BitboardGame;
|
||||
|
||||
public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
||||
|
||||
public record Score(int black, int white) {}
|
||||
|
||||
private final long notAFile = 0xfefefefefefefefeL;
|
||||
private final long notHFile = 0x7f7f7f7f7f7f7f7fL;
|
||||
|
||||
@@ -253,7 +250,9 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public BitboardReversi deepCopy() {return new BitboardReversi(this);}
|
||||
public BitboardReversi deepCopy() {
|
||||
return new BitboardReversi(this);
|
||||
}
|
||||
|
||||
public PlayResult play(long move) {
|
||||
final long flips = getFlips(move);
|
||||
@@ -296,13 +295,6 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
||||
return state;
|
||||
}
|
||||
|
||||
public Score getScore() {
|
||||
return new Score(
|
||||
Long.bitCount(getPlayerBitboard(0)),
|
||||
Long.bitCount(getPlayerBitboard(1))
|
||||
);
|
||||
}
|
||||
|
||||
public int getWinner(){
|
||||
final long black = getPlayerBitboard(0);
|
||||
final long white = getPlayerBitboard(1);
|
||||
@@ -316,8 +308,51 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
||||
else if (blackCount > whiteCount){
|
||||
return 0;
|
||||
}
|
||||
else{
|
||||
else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float rateMove(long move) {
|
||||
final long corners = 0x8100000000000081L;
|
||||
|
||||
if ((move & corners) != 0L) {
|
||||
return 0.4f;
|
||||
}
|
||||
|
||||
final long xSquares = 0x0042000000004200L;
|
||||
|
||||
if ((move & xSquares) != 0) {
|
||||
return -0.4f;
|
||||
}
|
||||
|
||||
final long cSquares = 0x4281000000008142L;
|
||||
|
||||
if ((move & cSquares) != 0) {
|
||||
return -0.1f;
|
||||
}
|
||||
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long heuristicMove(long legalMoves) {
|
||||
long bestMove = 0L;
|
||||
float bestMoveRate = Float.NEGATIVE_INFINITY;
|
||||
|
||||
while (legalMoves != 0L) {
|
||||
final long move = legalMoves & -legalMoves;
|
||||
final float moveRate = rateMove(move);
|
||||
|
||||
if (moveRate > bestMoveRate) {
|
||||
bestMove = move;
|
||||
bestMoveRate = moveRate;
|
||||
}
|
||||
|
||||
legalMoves &= ~move;
|
||||
}
|
||||
|
||||
return bestMove;
|
||||
}
|
||||
}
|
||||
@@ -104,4 +104,14 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
|
||||
public BitboardTicTacToe deepCopy() {
|
||||
return new BitboardTicTacToe(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float rateMove(long move) {
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long heuristicMove(long legalMoves) {
|
||||
return legalMoves;
|
||||
}
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package org.toop.game.players.ai;
|
||||
|
||||
import org.toop.framework.gameFramework.GameState;
|
||||
import org.toop.framework.gameFramework.model.game.PlayResult;
|
||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
private static class Node {
|
||||
public TurnBasedGame<?> state;
|
||||
public long move;
|
||||
|
||||
public Node parent;
|
||||
|
||||
public int expanded;
|
||||
public Node[] children;
|
||||
|
||||
public int visits;
|
||||
public float value;
|
||||
|
||||
public Node(TurnBasedGame<?> state, long move, Node parent) {
|
||||
this.state = state;
|
||||
this.move = move;
|
||||
|
||||
this.parent = parent;
|
||||
|
||||
this.expanded = 0;
|
||||
this.children = new Node[Long.bitCount(state.getLegalMoves())];
|
||||
|
||||
this.visits = 0;
|
||||
this.value = 0.0f;
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
this(state, 0L, null);
|
||||
}
|
||||
|
||||
public boolean isFullyExpanded() {
|
||||
return expanded >= children.length;
|
||||
}
|
||||
|
||||
float calculateUCT() {
|
||||
float exploitation = visits <= 0? 0 : value / visits;
|
||||
float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits));
|
||||
|
||||
return exploitation + exploration;
|
||||
}
|
||||
|
||||
public Node bestUCTChild() {
|
||||
int bestChildIndex = -1;
|
||||
float bestScore = Float.NEGATIVE_INFINITY;
|
||||
|
||||
for (int i = 0; i < expanded; i++) {
|
||||
final float score = calculateUCT();
|
||||
|
||||
if (score > bestScore) {
|
||||
bestChildIndex = i;
|
||||
bestScore = score;
|
||||
}
|
||||
}
|
||||
|
||||
return bestChildIndex >= 0? children[bestChildIndex] : this;
|
||||
}
|
||||
}
|
||||
|
||||
private final int milliseconds;
|
||||
|
||||
public MCTSAI(int milliseconds) {
|
||||
this.milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
public MCTSAI(MCTSAI<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MCTSAI<T> deepCopy() {
|
||||
return new MCTSAI<>(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
Node root = new Node(game.deepCopy());
|
||||
|
||||
long endTime = System.currentTimeMillis() + milliseconds;
|
||||
|
||||
while (System.currentTimeMillis() <= endTime) {
|
||||
Node node = selection(root);
|
||||
long legalMoves = node.state.getLegalMoves();
|
||||
|
||||
if (legalMoves != 0) {
|
||||
node = expansion(node, legalMoves);
|
||||
}
|
||||
|
||||
float result = 0.0f;
|
||||
|
||||
if (node.state.getLegalMoves() != 0) {
|
||||
result = simulation(node.state, game.getCurrentTurn());
|
||||
}
|
||||
|
||||
backPropagation(node, result);
|
||||
}
|
||||
|
||||
int mostVisitedIndex = -1;
|
||||
int mostVisits = -1;
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].visits > mostVisits) {
|
||||
mostVisitedIndex = i;
|
||||
mostVisits = root.children[i].visits;
|
||||
}
|
||||
}
|
||||
|
||||
return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves());
|
||||
}
|
||||
|
||||
private Node selection(Node node) {
|
||||
while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) {
|
||||
node = node.bestUCTChild();
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
private Node expansion(Node node, long legalMoves) {
|
||||
for (int i = 0; i < node.expanded; i++) {
|
||||
legalMoves &= ~node.children[i].move;
|
||||
}
|
||||
|
||||
if (legalMoves == 0L) {
|
||||
return node;
|
||||
}
|
||||
|
||||
long move = randomSetBit(legalMoves);
|
||||
|
||||
TurnBasedGame<?> copy = node.state.deepCopy();
|
||||
copy.play(move);
|
||||
|
||||
Node newlyExpanded = new Node(copy, move, node);
|
||||
|
||||
node.children[node.expanded] = newlyExpanded;
|
||||
node.expanded++;
|
||||
|
||||
return newlyExpanded;
|
||||
}
|
||||
|
||||
private float simulation(TurnBasedGame<?> state, int playerIndex) {
|
||||
TurnBasedGame<?> copy = state.deepCopy();
|
||||
long legalMoves = copy.getLegalMoves();
|
||||
PlayResult result = null;
|
||||
|
||||
while (legalMoves != 0) {
|
||||
result = copy.play(randomSetBit(legalMoves));
|
||||
legalMoves = copy.getLegalMoves();
|
||||
}
|
||||
|
||||
if (result.state() == GameState.WIN) {
|
||||
if (result.player() == playerIndex) {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
return -0.2f;
|
||||
}
|
||||
|
||||
private void backPropagation(Node node, float value) {
|
||||
while (node != null) {
|
||||
node.visits++;
|
||||
node.value += value;
|
||||
node = node.parent;
|
||||
}
|
||||
}
|
||||
|
||||
public static long randomSetBit(long value) {
|
||||
Random random = new Random();
|
||||
|
||||
int count = Long.bitCount(value);
|
||||
int target = random.nextInt(count);
|
||||
|
||||
while (true) {
|
||||
int bit = Long.numberOfTrailingZeros(value);
|
||||
if (target == 0) {
|
||||
return 1L << bit;
|
||||
}
|
||||
value &= value - 1;
|
||||
target--;
|
||||
}
|
||||
}
|
||||
}
|
||||
250
game/src/main/java/org/toop/game/players/ai/MCTSAI1.java
Normal file
250
game/src/main/java/org/toop/game/players/ai/MCTSAI1.java
Normal file
@@ -0,0 +1,250 @@
|
||||
package org.toop.game.players.ai;
|
||||
|
||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
public class MCTSAI1<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
private static class Node {
|
||||
public TurnBasedGame<?> state;
|
||||
|
||||
public long move;
|
||||
public long unexpandedMoves;
|
||||
|
||||
public Node parent;
|
||||
|
||||
public Node[] children;
|
||||
public int expanded;
|
||||
|
||||
public float value;
|
||||
public int visits;
|
||||
|
||||
public boolean solved;
|
||||
public float solvedValue;
|
||||
|
||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||
final long legalMoves = state.getLegalMoves();
|
||||
|
||||
this.state = state;
|
||||
|
||||
this.move = move;
|
||||
this.unexpandedMoves = legalMoves;
|
||||
|
||||
this.parent = parent;
|
||||
|
||||
this.children = new Node[Long.bitCount(legalMoves)];
|
||||
this.expanded = 0;
|
||||
|
||||
this.value = 0.0f;
|
||||
this.visits = 0;
|
||||
|
||||
this.solved = false;
|
||||
this.solvedValue = 0.0f;
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
this(state, null, 0L);
|
||||
}
|
||||
|
||||
public boolean isFullyExpanded() {
|
||||
return expanded == children.length;
|
||||
}
|
||||
|
||||
public float calculateUCT(int parentVisits) {
|
||||
if (visits == 0) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
final float exploitation = value / visits;
|
||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||
|
||||
return exploitation + exploration;
|
||||
}
|
||||
|
||||
public Node bestUCTChild() {
|
||||
Node highestUCTChild = null;
|
||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
||||
|
||||
for (int i = 0; i < expanded; i++) {
|
||||
final float childUCT = children[i].calculateUCT(visits);
|
||||
|
||||
if (childUCT > highestUCT) {
|
||||
highestUCTChild = children[i];
|
||||
highestUCT = childUCT;
|
||||
}
|
||||
}
|
||||
|
||||
return highestUCTChild;
|
||||
}
|
||||
}
|
||||
|
||||
private static final Random random = new Random();
|
||||
|
||||
private final int milliseconds;
|
||||
|
||||
public MCTSAI1(int milliseconds) {
|
||||
this.milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
public MCTSAI1(MCTSAI1<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MCTSAI1<T> deepCopy() {
|
||||
return new MCTSAI1<>(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
final Node root = new Node(game, null, 0L);
|
||||
|
||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||
|
||||
while (System.nanoTime() < endTime) {
|
||||
Node leaf = selection(root);
|
||||
leaf = expansion(leaf);
|
||||
final float value = simulation(leaf);
|
||||
backPropagation(leaf, value);
|
||||
}
|
||||
|
||||
final Node mostVisitedChild = mostVisitedChild(root);
|
||||
return mostVisitedChild.move;
|
||||
}
|
||||
|
||||
private Node mostVisitedChild(Node root) {
|
||||
Node mostVisitedChild = null;
|
||||
int mostVisited = -1;
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].visits > mostVisited) {
|
||||
mostVisitedChild = root.children[i];
|
||||
mostVisited = root.children[i].visits;
|
||||
}
|
||||
}
|
||||
|
||||
return mostVisitedChild;
|
||||
}
|
||||
|
||||
private Node selection(Node root) {
|
||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
root = root.bestUCTChild();
|
||||
}
|
||||
|
||||
return root;
|
||||
}
|
||||
|
||||
private Node expansion(Node leaf) {
|
||||
if (leaf.unexpandedMoves == 0L) {
|
||||
return leaf;
|
||||
}
|
||||
|
||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
||||
|
||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||
copiedState.play(unexpandedMove);
|
||||
|
||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
||||
|
||||
leaf.children[leaf.expanded] = expandedChild;
|
||||
leaf.expanded++;
|
||||
|
||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
||||
|
||||
return expandedChild;
|
||||
}
|
||||
|
||||
private float simulation(Node leaf) {
|
||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
||||
|
||||
while (!copiedState.isTerminal()) {
|
||||
final long legalMoves = copiedState.getLegalMoves();
|
||||
final long randomMove = randomSetBit(legalMoves);
|
||||
|
||||
copiedState.play(randomMove);
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() == playerIndex) {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() >= 0) {
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
private void backPropagation(Node leaf, float value) {
|
||||
while (leaf != null) {
|
||||
leaf.value += value;
|
||||
leaf.visits++;
|
||||
|
||||
if (!leaf.solved) {
|
||||
updateSolvedStatus(leaf);
|
||||
}
|
||||
|
||||
value = -value;
|
||||
leaf = leaf.parent;
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSolvedStatus(Node node) {
|
||||
if (node.state.isTerminal()) {
|
||||
node.solved = true;
|
||||
|
||||
final int winner = node.state.getWinner();
|
||||
final int mover = 1 - node.state.getCurrentTurn();
|
||||
|
||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.isFullyExpanded()) {
|
||||
boolean allChildrenSolved = true;
|
||||
boolean foundWinningMove = false;
|
||||
boolean foundDrawMove = false;
|
||||
|
||||
for (final Node child : node.children) {
|
||||
if (child.solved) {
|
||||
if (child.solvedValue == -1.0f) {
|
||||
foundWinningMove = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (child.solvedValue == 0.0f) {
|
||||
foundDrawMove = true;
|
||||
}
|
||||
} else {
|
||||
allChildrenSolved = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundWinningMove) {
|
||||
node.solved = true;
|
||||
node.solvedValue = 1.0f;
|
||||
} else if (allChildrenSolved) {
|
||||
node.solved = true;
|
||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private long randomSetBit(long value) {
|
||||
if (0L == value) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
final int bitCount = Long.bitCount(value);
|
||||
final int randomBitCount = random.nextInt(bitCount);
|
||||
|
||||
for (int i = 0; i < randomBitCount; i++) {
|
||||
value &= value - 1;
|
||||
}
|
||||
|
||||
return value & -value;
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,9 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
public float value;
|
||||
public int visits;
|
||||
|
||||
public boolean solved;
|
||||
public float solvedValue;
|
||||
|
||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||
final long legalMoves = state.getLegalMoves();
|
||||
|
||||
@@ -35,6 +38,9 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
this.value = 0.0f;
|
||||
this.visits = 0;
|
||||
|
||||
this.solved = false;
|
||||
this.solvedValue = 0.0f;
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
@@ -46,6 +52,10 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
}
|
||||
|
||||
public float calculateUCT(int parentVisits) {
|
||||
if (visits == 0) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
final float exploitation = value / visits;
|
||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||
|
||||
@@ -63,24 +73,28 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
highestUCTChild = children[i];
|
||||
highestUCT = childUCT;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return highestUCTChild;
|
||||
}
|
||||
}
|
||||
|
||||
private final Random random;
|
||||
private static final Random random = new Random();
|
||||
|
||||
private final int milliseconds;
|
||||
|
||||
private Node root;
|
||||
|
||||
public MCTSAI2(int milliseconds) {
|
||||
this.random = new Random();
|
||||
this.milliseconds = milliseconds;
|
||||
|
||||
this.root = null;
|
||||
}
|
||||
|
||||
public MCTSAI2(MCTSAI2<?> other) {
|
||||
this.random = other.random;
|
||||
public MCTSAI2(MCTSAI2<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
|
||||
this.root = other.root;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -90,7 +104,7 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
final Node root = new Node(game, null, 0L);
|
||||
root = findOrResetRoot(root, game);
|
||||
|
||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||
|
||||
@@ -102,8 +116,11 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
}
|
||||
|
||||
final Node mostVisitedChild = mostVisitedChild(root);
|
||||
final long move = mostVisitedChild.move;
|
||||
|
||||
return mostVisitedChild != null? mostVisitedChild.move : 0L;
|
||||
root = findChildByMove(root, move);
|
||||
|
||||
return move;
|
||||
}
|
||||
|
||||
private Node mostVisitedChild(Node root) {
|
||||
@@ -120,8 +137,51 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
return mostVisitedChild;
|
||||
}
|
||||
|
||||
private Node findOrResetRoot(Node root, T game) {
|
||||
if (root == null) {
|
||||
return new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
||||
return root;
|
||||
}
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
||||
root.children[i].parent = null;
|
||||
return root.children[i];
|
||||
}
|
||||
}
|
||||
|
||||
return new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
private Node findChildByMove(Node root, long move) {
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].move == move) {
|
||||
root.children[i].parent = null;
|
||||
return root.children[i];
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private boolean areStatesEqual(long[] state1, long[] state2) {
|
||||
if (state1.length != state2.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < state1.length; i++) {
|
||||
if (state1[i] != state2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
private Node selection(Node root) {
|
||||
while (root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
root = root.bestUCTChild();
|
||||
}
|
||||
|
||||
@@ -161,7 +221,9 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
if (copiedState.getWinner() == playerIndex) {
|
||||
return 1.0f;
|
||||
} else if (copiedState.getWinner() >= 0) {
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() >= 0) {
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
@@ -173,11 +235,57 @@ public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
leaf.value += value;
|
||||
leaf.visits++;
|
||||
|
||||
if (!leaf.solved) {
|
||||
updateSolvedStatus(leaf);
|
||||
}
|
||||
|
||||
value = -value;
|
||||
leaf = leaf.parent;
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSolvedStatus(Node node) {
|
||||
if (node.state.isTerminal()) {
|
||||
node.solved = true;
|
||||
|
||||
final int winner = node.state.getWinner();
|
||||
final int mover = 1 - node.state.getCurrentTurn();
|
||||
|
||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.isFullyExpanded()) {
|
||||
boolean allChildrenSolved = true;
|
||||
boolean foundWinningMove = false;
|
||||
boolean foundDrawMove = false;
|
||||
|
||||
for (final Node child : node.children) {
|
||||
if (child.solved) {
|
||||
if (child.solvedValue == -1.0f) {
|
||||
foundWinningMove = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (child.solvedValue == 0.0f) {
|
||||
foundDrawMove = true;
|
||||
}
|
||||
} else {
|
||||
allChildrenSolved = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundWinningMove) {
|
||||
node.solved = true;
|
||||
node.solvedValue = 1.0f;
|
||||
} else if (allChildrenSolved) {
|
||||
node.solved = true;
|
||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private long randomSetBit(long value) {
|
||||
if (0L == value) {
|
||||
return 0;
|
||||
|
||||
@@ -3,7 +3,13 @@ package org.toop.game.players.ai;
|
||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
private static class Node {
|
||||
@@ -20,6 +26,9 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
public float value;
|
||||
public int visits;
|
||||
|
||||
public boolean solved;
|
||||
public float solvedValue;
|
||||
|
||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||
final long legalMoves = state.getLegalMoves();
|
||||
|
||||
@@ -35,6 +44,9 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
this.value = 0.0f;
|
||||
this.visits = 0;
|
||||
|
||||
this.solved = false;
|
||||
this.solvedValue = 0.0f;
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
@@ -46,6 +58,10 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
}
|
||||
|
||||
public float calculateUCT(int parentVisits) {
|
||||
if (visits == 0) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
final float exploitation = value / visits;
|
||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||
|
||||
@@ -63,30 +79,25 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
highestUCTChild = children[i];
|
||||
highestUCT = childUCT;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return highestUCTChild;
|
||||
}
|
||||
}
|
||||
|
||||
private final Random random;
|
||||
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
||||
|
||||
private Node root;
|
||||
private final int milliseconds;
|
||||
private final int threads;
|
||||
|
||||
public MCTSAI3(int milliseconds) {
|
||||
this.random = new Random();
|
||||
|
||||
this.root = null;
|
||||
public MCTSAI3(int milliseconds, int threads) {
|
||||
this.milliseconds = milliseconds;
|
||||
this.threads = threads;
|
||||
}
|
||||
|
||||
public MCTSAI3(MCTSAI3<?> other) {
|
||||
this.random = other.random;
|
||||
|
||||
this.root = other.root;
|
||||
public MCTSAI3(MCTSAI3<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
this.threads = other.threads;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -96,23 +107,57 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
detectRoot(game);
|
||||
|
||||
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||
|
||||
while (System.nanoTime() < endTime) {
|
||||
Node leaf = selection(root);
|
||||
leaf = expansion(leaf);
|
||||
final float value = simulation(leaf);
|
||||
backPropagation(leaf, value);
|
||||
final List<Callable<Node>> tasks = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < threads; i++) {
|
||||
tasks.add(() -> {
|
||||
final Node localRoot = new Node(game.deepCopy());
|
||||
|
||||
while (System.nanoTime() < endTime) {
|
||||
Node leaf = selection(localRoot);
|
||||
leaf = expansion(leaf);
|
||||
final float value = simulation(leaf);
|
||||
backPropagation(leaf, value);
|
||||
}
|
||||
|
||||
return localRoot;
|
||||
});
|
||||
}
|
||||
|
||||
final Node mostVisitedChild = mostVisitedChild(root);
|
||||
final long move = mostVisitedChild != null? mostVisitedChild.move : 0L;
|
||||
try {
|
||||
final List<Future<Node>> results = pool.invokeAll(tasks);
|
||||
|
||||
newRoot(move);
|
||||
pool.shutdown();
|
||||
|
||||
return move;
|
||||
final Node root = new Node(game.deepCopy());
|
||||
|
||||
for (int i = 0; i < root.children.length; i++) {
|
||||
expansion(root);
|
||||
}
|
||||
|
||||
for (final Future<Node> result : results) {
|
||||
final Node localRoot = result.get();
|
||||
|
||||
for (final Node localChild : localRoot.children) {
|
||||
for (int i = 0; i < root.children.length; i++) {
|
||||
if (localChild.move == root.children[i].move) {
|
||||
root.children[i].visits += localChild.visits;
|
||||
root.visits += localChild.visits;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final Node mostVisitedChild = mostVisitedChild(root);
|
||||
return mostVisitedChild.move;
|
||||
} catch (Exception _) {
|
||||
final long legalMoves = game.getLegalMoves();
|
||||
return randomSetBit(legalMoves);
|
||||
}
|
||||
}
|
||||
|
||||
private Node mostVisitedChild(Node root) {
|
||||
@@ -129,62 +174,8 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
return mostVisitedChild;
|
||||
}
|
||||
|
||||
private void detectRoot(T game) {
|
||||
if (root == null) {
|
||||
root = new Node(game.deepCopy());
|
||||
return;
|
||||
}
|
||||
|
||||
final long[] currentBoards = game.getBoard();
|
||||
final long[] rootBoards = root.state.getBoard();
|
||||
|
||||
boolean detected = true;
|
||||
|
||||
for (int i = 0; i < rootBoards.length; i++) {
|
||||
if (rootBoards[i] != currentBoards[i]) {
|
||||
detected = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (detected) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
final Node child = root.children[i];
|
||||
|
||||
final long[] childBoards = child.state.getBoard();
|
||||
|
||||
detected = true;
|
||||
|
||||
for (int j = 0; j < childBoards.length; j++) {
|
||||
if (childBoards[j] != currentBoards[j]) {
|
||||
detected = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (detected) {
|
||||
root = child;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
root = new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
private void newRoot(long move) {
|
||||
for (final Node child : root.children) {
|
||||
if (child.move == move) {
|
||||
root = child;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Node selection(Node root) {
|
||||
while (root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
root = root.bestUCTChild();
|
||||
}
|
||||
|
||||
@@ -224,7 +215,9 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
if (copiedState.getWinner() == playerIndex) {
|
||||
return 1.0f;
|
||||
} else if (copiedState.getWinner() >= 0) {
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() >= 0) {
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
@@ -236,18 +229,64 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
leaf.value += value;
|
||||
leaf.visits++;
|
||||
|
||||
if (!leaf.solved) {
|
||||
updateSolvedStatus(leaf);
|
||||
}
|
||||
|
||||
value = -value;
|
||||
leaf = leaf.parent;
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSolvedStatus(Node node) {
|
||||
if (node.state.isTerminal()) {
|
||||
node.solved = true;
|
||||
|
||||
final int winner = node.state.getWinner();
|
||||
final int mover = 1 - node.state.getCurrentTurn();
|
||||
|
||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.isFullyExpanded()) {
|
||||
boolean allChildrenSolved = true;
|
||||
boolean foundWinningMove = false;
|
||||
boolean foundDrawMove = false;
|
||||
|
||||
for (final Node child : node.children) {
|
||||
if (child.solved) {
|
||||
if (child.solvedValue == -1.0f) {
|
||||
foundWinningMove = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (child.solvedValue == 0.0f) {
|
||||
foundDrawMove = true;
|
||||
}
|
||||
} else {
|
||||
allChildrenSolved = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundWinningMove) {
|
||||
node.solved = true;
|
||||
node.solvedValue = 1.0f;
|
||||
} else if (allChildrenSolved) {
|
||||
node.solved = true;
|
||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private long randomSetBit(long value) {
|
||||
if (0L == value) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
final int bitCount = Long.bitCount(value);
|
||||
final int randomBitCount = random.nextInt(bitCount);
|
||||
final int randomBitCount = random.get().nextInt(bitCount);
|
||||
|
||||
for (int i = 0; i < randomBitCount; i++) {
|
||||
value &= value - 1;
|
||||
@@ -255,4 +294,4 @@ public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
|
||||
return value & -value;
|
||||
}
|
||||
}
|
||||
}
|
||||
359
game/src/main/java/org/toop/game/players/ai/MCTSAI4.java
Normal file
359
game/src/main/java/org/toop/game/players/ai/MCTSAI4.java
Normal file
@@ -0,0 +1,359 @@
|
||||
package org.toop.game.players.ai;
|
||||
|
||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
public class MCTSAI4<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
private static class Node {
|
||||
public TurnBasedGame<?> state;
|
||||
|
||||
public long move;
|
||||
public long unexpandedMoves;
|
||||
|
||||
public Node parent;
|
||||
|
||||
public Node[] children;
|
||||
public int expanded;
|
||||
|
||||
public float value;
|
||||
public int visits;
|
||||
|
||||
public boolean solved;
|
||||
public float solvedValue;
|
||||
|
||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||
final long legalMoves = state.getLegalMoves();
|
||||
|
||||
this.state = state;
|
||||
|
||||
this.move = move;
|
||||
this.unexpandedMoves = legalMoves;
|
||||
|
||||
this.parent = parent;
|
||||
|
||||
this.children = new Node[Long.bitCount(legalMoves)];
|
||||
this.expanded = 0;
|
||||
|
||||
this.value = 0.0f;
|
||||
this.visits = 0;
|
||||
|
||||
this.solved = false;
|
||||
this.solvedValue = 0.0f;
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
this(state, null, 0L);
|
||||
}
|
||||
|
||||
public boolean isFullyExpanded() {
|
||||
return expanded == children.length;
|
||||
}
|
||||
|
||||
public float calculateUCT(int parentVisits) {
|
||||
if (visits == 0) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
final float exploitation = value / visits;
|
||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||
|
||||
return exploitation + exploration;
|
||||
}
|
||||
|
||||
public Node bestUCTChild() {
|
||||
Node highestUCTChild = null;
|
||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
||||
|
||||
for (int i = 0; i < expanded; i++) {
|
||||
final float childUCT = children[i].calculateUCT(visits);
|
||||
|
||||
if (childUCT > highestUCT) {
|
||||
highestUCTChild = children[i];
|
||||
highestUCT = childUCT;
|
||||
}
|
||||
}
|
||||
|
||||
return highestUCTChild;
|
||||
}
|
||||
}
|
||||
|
||||
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
||||
|
||||
private final int milliseconds;
|
||||
private final int threads;
|
||||
|
||||
private final Node[] threadRoots;
|
||||
|
||||
public MCTSAI4(int milliseconds, int threads) {
|
||||
this.milliseconds = milliseconds;
|
||||
this.threads = threads;
|
||||
|
||||
this.threadRoots = new Node[threads];
|
||||
}
|
||||
|
||||
public MCTSAI4(MCTSAI4<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
this.threads = other.threads;
|
||||
|
||||
this.threadRoots = other.threadRoots;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MCTSAI4<T> deepCopy() {
|
||||
return new MCTSAI4<>(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
for (int i = 0; i < threads; i++) {
|
||||
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
|
||||
}
|
||||
|
||||
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||
|
||||
final List<Callable<Node>> tasks = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < threads; i++) {
|
||||
final int threadIndex = i;
|
||||
|
||||
tasks.add(() -> {
|
||||
final Node localRoot = threadRoots[threadIndex];
|
||||
|
||||
while (System.nanoTime() < endTime) {
|
||||
Node leaf = selection(localRoot);
|
||||
leaf = expansion(leaf);
|
||||
final float value = simulation(leaf);
|
||||
backPropagation(leaf, value);
|
||||
}
|
||||
|
||||
return localRoot;
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
final List<Future<Node>> results = pool.invokeAll(tasks);
|
||||
|
||||
pool.shutdown();
|
||||
|
||||
final Node root = new Node(game.deepCopy());
|
||||
|
||||
for (int i = 0; i < root.children.length; i++) {
|
||||
expansion(root);
|
||||
}
|
||||
|
||||
for (final Future<Node> result : results) {
|
||||
final Node localRoot = result.get();
|
||||
|
||||
for (final Node localChild : localRoot.children) {
|
||||
for (int i = 0; i < root.children.length; i++) {
|
||||
if (localChild.move == root.children[i].move) {
|
||||
root.children[i].visits += localChild.visits;
|
||||
root.visits += localChild.visits;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final Node mostVisitedChild = mostVisitedChild(root);
|
||||
final long move = mostVisitedChild.move;
|
||||
|
||||
for (int i = 0; i < threads; i++) {
|
||||
threadRoots[i] = findChildByMove(threadRoots[i], move);
|
||||
}
|
||||
|
||||
return move;
|
||||
} catch (Exception _) {
|
||||
final long legalMoves = game.getLegalMoves();
|
||||
return randomSetBit(legalMoves);
|
||||
}
|
||||
}
|
||||
|
||||
private Node mostVisitedChild(Node root) {
|
||||
Node mostVisitedChild = null;
|
||||
int mostVisited = -1;
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].visits > mostVisited) {
|
||||
mostVisitedChild = root.children[i];
|
||||
mostVisited = root.children[i].visits;
|
||||
}
|
||||
}
|
||||
|
||||
return mostVisitedChild;
|
||||
}
|
||||
|
||||
private Node findOrResetRoot(Node root, T game) {
|
||||
if (root == null) {
|
||||
return new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
||||
return root;
|
||||
}
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
||||
root.children[i].parent = null;
|
||||
return root.children[i];
|
||||
}
|
||||
}
|
||||
|
||||
return new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
private Node findChildByMove(Node root, long move) {
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].move == move) {
|
||||
root.children[i].parent = null;
|
||||
return root.children[i];
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private boolean areStatesEqual(long[] state1, long[] state2) {
|
||||
if (state1.length != state2.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < state1.length; i++) {
|
||||
if (state1[i] != state2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private Node selection(Node root) {
|
||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
root = root.bestUCTChild();
|
||||
}
|
||||
|
||||
return root;
|
||||
}
|
||||
|
||||
private Node expansion(Node leaf) {
|
||||
if (leaf.unexpandedMoves == 0L) {
|
||||
return leaf;
|
||||
}
|
||||
|
||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
||||
|
||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||
copiedState.play(unexpandedMove);
|
||||
|
||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
||||
|
||||
leaf.children[leaf.expanded] = expandedChild;
|
||||
leaf.expanded++;
|
||||
|
||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
||||
|
||||
return expandedChild;
|
||||
}
|
||||
|
||||
private float simulation(Node leaf) {
|
||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
||||
|
||||
while (!copiedState.isTerminal()) {
|
||||
final long legalMoves = copiedState.getLegalMoves();
|
||||
final long randomMove = randomSetBit(legalMoves);
|
||||
|
||||
copiedState.play(randomMove);
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() == playerIndex) {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() >= 0) {
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
private void backPropagation(Node leaf, float value) {
|
||||
while (leaf != null) {
|
||||
leaf.value += value;
|
||||
leaf.visits++;
|
||||
|
||||
if (!leaf.solved) {
|
||||
updateSolvedStatus(leaf);
|
||||
}
|
||||
|
||||
value = -value;
|
||||
leaf = leaf.parent;
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSolvedStatus(Node node) {
|
||||
if (node.state.isTerminal()) {
|
||||
node.solved = true;
|
||||
|
||||
final int winner = node.state.getWinner();
|
||||
final int mover = 1 - node.state.getCurrentTurn();
|
||||
|
||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.isFullyExpanded()) {
|
||||
boolean allChildrenSolved = true;
|
||||
boolean foundWinningMove = false;
|
||||
boolean foundDrawMove = false;
|
||||
|
||||
for (final Node child : node.children) {
|
||||
if (child.solved) {
|
||||
if (child.solvedValue == -1.0f) {
|
||||
foundWinningMove = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (child.solvedValue == 0.0f) {
|
||||
foundDrawMove = true;
|
||||
}
|
||||
} else {
|
||||
allChildrenSolved = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundWinningMove) {
|
||||
node.solved = true;
|
||||
node.solvedValue = 1.0f;
|
||||
} else if (allChildrenSolved) {
|
||||
node.solved = true;
|
||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private long randomSetBit(long value) {
|
||||
if (0L == value) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
final int bitCount = Long.bitCount(value);
|
||||
final int randomBitCount = random.get().nextInt(bitCount);
|
||||
|
||||
for (int i = 0; i < randomBitCount; i++) {
|
||||
value &= value - 1;
|
||||
}
|
||||
|
||||
return value & -value;
|
||||
}
|
||||
}
|
||||
371
game/src/main/java/org/toop/game/players/ai/MCTSAI5.java
Normal file
371
game/src/main/java/org/toop/game/players/ai/MCTSAI5.java
Normal file
@@ -0,0 +1,371 @@
|
||||
package org.toop.game.players.ai;
|
||||
|
||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
public class MCTSAI5<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
private static class Node {
|
||||
public TurnBasedGame<?> state;
|
||||
|
||||
public long move;
|
||||
public long unexpandedMoves;
|
||||
|
||||
public Node parent;
|
||||
|
||||
public Node[] children;
|
||||
public int expanded;
|
||||
|
||||
public float value;
|
||||
public int visits;
|
||||
|
||||
public boolean solved;
|
||||
public float solvedValue;
|
||||
|
||||
public float heuristic;
|
||||
|
||||
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||
final long legalMoves = state.getLegalMoves();
|
||||
|
||||
this.state = state;
|
||||
|
||||
this.move = move;
|
||||
this.unexpandedMoves = legalMoves;
|
||||
|
||||
this.parent = parent;
|
||||
|
||||
this.children = new Node[Long.bitCount(legalMoves)];
|
||||
this.expanded = 0;
|
||||
|
||||
this.value = 0.0f;
|
||||
this.visits = 0;
|
||||
|
||||
this.solved = false;
|
||||
this.solvedValue = 0.0f;
|
||||
|
||||
this.heuristic = state.rateMove(move);
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
this(state, null, 0L);
|
||||
}
|
||||
|
||||
public boolean isFullyExpanded() {
|
||||
return expanded == children.length;
|
||||
}
|
||||
|
||||
public float calculateUCT(int parentVisits) {
|
||||
if (visits == 0) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
final float exploitation = value / visits;
|
||||
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||
final float bias = heuristic / visits;
|
||||
|
||||
return exploitation + exploration + bias;
|
||||
}
|
||||
|
||||
public Node bestUCTChild() {
|
||||
Node highestUCTChild = null;
|
||||
float highestUCT = Float.NEGATIVE_INFINITY;
|
||||
|
||||
for (int i = 0; i < expanded; i++) {
|
||||
final float childUCT = children[i].calculateUCT(visits);
|
||||
|
||||
if (childUCT > highestUCT) {
|
||||
highestUCTChild = children[i];
|
||||
highestUCT = childUCT;
|
||||
}
|
||||
}
|
||||
|
||||
return highestUCTChild;
|
||||
}
|
||||
}
|
||||
|
||||
private static final ThreadLocal<Random> random = ThreadLocal.withInitial(Random::new);
|
||||
|
||||
private final int milliseconds;
|
||||
private final int threads;
|
||||
|
||||
private final Node[] threadRoots;
|
||||
|
||||
public MCTSAI5(int milliseconds, int threads) {
|
||||
this.milliseconds = milliseconds;
|
||||
this.threads = threads;
|
||||
|
||||
this.threadRoots = new Node[threads];
|
||||
}
|
||||
|
||||
public MCTSAI5(MCTSAI5<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
this.threads = other.threads;
|
||||
|
||||
this.threadRoots = other.threadRoots;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MCTSAI5<T> deepCopy() {
|
||||
return new MCTSAI5<>(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
for (int i = 0; i < threads; i++) {
|
||||
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
|
||||
}
|
||||
|
||||
final ExecutorService pool = Executors.newFixedThreadPool(threads);
|
||||
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||
|
||||
final List<Callable<Node>> tasks = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < threads; i++) {
|
||||
final int threadIndex = i;
|
||||
|
||||
tasks.add(() -> {
|
||||
final Node localRoot = threadRoots[threadIndex];
|
||||
|
||||
while (System.nanoTime() < endTime) {
|
||||
Node leaf = selection(localRoot);
|
||||
leaf = expansion(leaf);
|
||||
final float value = simulation(leaf);
|
||||
backPropagation(leaf, value);
|
||||
}
|
||||
|
||||
return localRoot;
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
final List<Future<Node>> results = pool.invokeAll(tasks);
|
||||
|
||||
pool.shutdown();
|
||||
|
||||
final Node root = new Node(game.deepCopy());
|
||||
|
||||
for (int i = 0; i < root.children.length; i++) {
|
||||
expansion(root);
|
||||
}
|
||||
|
||||
for (final Future<Node> result : results) {
|
||||
final Node localRoot = result.get();
|
||||
|
||||
for (final Node localChild : localRoot.children) {
|
||||
for (int i = 0; i < root.children.length; i++) {
|
||||
if (localChild.move == root.children[i].move) {
|
||||
root.children[i].visits += localChild.visits;
|
||||
root.visits += localChild.visits;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final Node mostVisitedChild = mostVisitedChild(root);
|
||||
final long move = mostVisitedChild.move;
|
||||
|
||||
for (int i = 0; i < threads; i++) {
|
||||
threadRoots[i] = findChildByMove(threadRoots[i], move);
|
||||
}
|
||||
|
||||
return move;
|
||||
} catch (Exception _) {
|
||||
final long legalMoves = game.getLegalMoves();
|
||||
return randomSetBit(legalMoves);
|
||||
}
|
||||
}
|
||||
|
||||
private Node mostVisitedChild(Node root) {
|
||||
Node mostVisitedChild = null;
|
||||
int mostVisited = -1;
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].visits > mostVisited) {
|
||||
mostVisitedChild = root.children[i];
|
||||
mostVisited = root.children[i].visits;
|
||||
}
|
||||
}
|
||||
|
||||
return mostVisitedChild;
|
||||
}
|
||||
|
||||
private Node findOrResetRoot(Node root, T game) {
|
||||
if (root == null) {
|
||||
return new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
if (areStatesEqual(root.state.getBoard(), game.getBoard())) {
|
||||
return root;
|
||||
}
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (areStatesEqual(root.children[i].state.getBoard(), game.getBoard())) {
|
||||
root.children[i].parent = null;
|
||||
return root.children[i];
|
||||
}
|
||||
}
|
||||
|
||||
return new Node(game.deepCopy());
|
||||
}
|
||||
|
||||
private Node findChildByMove(Node root, long move) {
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].move == move) {
|
||||
root.children[i].parent = null;
|
||||
return root.children[i];
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private boolean areStatesEqual(long[] state1, long[] state2) {
|
||||
if (state1.length != state2.length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < state1.length; i++) {
|
||||
if (state1[i] != state2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private Node selection(Node root) {
|
||||
while (!root.solved && root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||
root = root.bestUCTChild();
|
||||
}
|
||||
|
||||
return root;
|
||||
}
|
||||
|
||||
private Node expansion(Node leaf) {
|
||||
if (leaf.unexpandedMoves == 0L) {
|
||||
return leaf;
|
||||
}
|
||||
|
||||
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
||||
|
||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||
copiedState.play(unexpandedMove);
|
||||
|
||||
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
||||
|
||||
leaf.children[leaf.expanded] = expandedChild;
|
||||
leaf.expanded++;
|
||||
|
||||
leaf.unexpandedMoves &= ~unexpandedMove;
|
||||
|
||||
return expandedChild;
|
||||
}
|
||||
|
||||
private float simulation(Node leaf) {
|
||||
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
||||
|
||||
while (!copiedState.isTerminal()) {
|
||||
final long legalMoves = copiedState.getLegalMoves();
|
||||
|
||||
long move = 0L;
|
||||
|
||||
if (random.get().nextFloat() > 0.9f) {
|
||||
move = copiedState.heuristicMove(legalMoves);
|
||||
} else {
|
||||
move = randomSetBit(legalMoves);
|
||||
}
|
||||
|
||||
copiedState.play(move);
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() == playerIndex) {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
if (copiedState.getWinner() >= 0) {
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
private void backPropagation(Node leaf, float value) {
|
||||
while (leaf != null) {
|
||||
leaf.value += value;
|
||||
leaf.visits++;
|
||||
|
||||
if (!leaf.solved) {
|
||||
updateSolvedStatus(leaf);
|
||||
}
|
||||
|
||||
value = -value;
|
||||
leaf = leaf.parent;
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSolvedStatus(Node node) {
|
||||
if (node.state.isTerminal()) {
|
||||
node.solved = true;
|
||||
|
||||
final int winner = node.state.getWinner();
|
||||
final int mover = 1 - node.state.getCurrentTurn();
|
||||
|
||||
node.solvedValue = winner == mover? 1.0f : winner == -1? 0.0f : -1.0f;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (node.isFullyExpanded()) {
|
||||
boolean allChildrenSolved = true;
|
||||
boolean foundWinningMove = false;
|
||||
boolean foundDrawMove = false;
|
||||
|
||||
for (final Node child : node.children) {
|
||||
if (child.solved) {
|
||||
if (child.solvedValue == -1.0f) {
|
||||
foundWinningMove = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (child.solvedValue == 0.0f) {
|
||||
foundDrawMove = true;
|
||||
}
|
||||
} else {
|
||||
allChildrenSolved = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (foundWinningMove) {
|
||||
node.solved = true;
|
||||
node.solvedValue = 1.0f;
|
||||
} else if (allChildrenSolved) {
|
||||
node.solved = true;
|
||||
node.solvedValue = foundDrawMove? 0.0f : -1.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private long randomSetBit(long value) {
|
||||
if (0L == value) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
final int bitCount = Long.bitCount(value);
|
||||
final int randomBitCount = random.get().nextInt(bitCount);
|
||||
|
||||
for (int i = 0; i < randomBitCount; i++) {
|
||||
value &= value - 1;
|
||||
}
|
||||
|
||||
return value & -value;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user