Bitboard fix; add MCTS v2 and MCTS v3. v3 is still in progress and v4 is coming soon.

This commit is contained in:
ramollia
2026-01-07 14:39:38 +01:00
parent e149588b60
commit df93b44d19
10 changed files with 568 additions and 29 deletions

View File

@@ -1,5 +1,7 @@
package org.toop.game;
import org.toop.framework.gameFramework.GameState;
import org.toop.framework.gameFramework.model.game.PlayResult;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.Player;
@@ -11,6 +13,8 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
private final int columnSize;
private final int rowSize;
protected PlayResult state;
private Player<T>[] players;
// long is 64 bits. Every game has a limit of 64 cells maximum.
@@ -20,6 +24,9 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
public BitboardGame(int columnSize, int rowSize, int playerCount, Player<T>[] players) {
this.columnSize = columnSize;
this.rowSize = rowSize;
this.state = new PlayResult(GameState.NORMAL, -1);
this.players = players;
this.playerBitboard = new long[playerCount];
@@ -30,6 +37,8 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
this.columnSize = other.columnSize;
this.rowSize = other.rowSize;
this.state = other.state;
this.playerBitboard = other.playerBitboard.clone();
this.currentTurn = other.currentTurn;
this.players = Arrays.stream(other.players)
@@ -61,7 +70,9 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
return getCurrentPlayerIndex();
}
public Player<T> getPlayer(int index) {return players[index];}
/**
 * Returns the player stored at the given position of the {@code players} array.
 *
 * @param index position in the {@code players} array; not range-checked here
 * @return the player at {@code index}
 */
public Player<T> getPlayer(int index) {
return players[index];
}
public int getCurrentPlayerIndex() {
return currentTurn % playerBitboard.length;
@@ -75,9 +86,17 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
return players[getCurrentPlayerIndex()];
}
/**
 * Returns the play result currently held in the {@code state} field.
 * The field starts as {@code GameState.NORMAL} with index -1 in the constructor
 * and is copied by the copy constructor.
 *
 * @return the stored {@link PlayResult}
 */
@Override
public PlayResult getState() {
return state;
}
/**
 * Reports whether the stored state marks the end of the game.
 *
 * @return {@code true} when the stored state is {@code GameState.WIN} or
 *         {@code GameState.DRAW}; {@code false} for every other state
 */
@Override
public boolean isTerminal() {
return state.state() == GameState.WIN || state.state() == GameState.DRAW;
}
@Override
@Override
public long[] getBoard() {return this.playerBitboard;}
public void nextTurn() {

View File

@@ -175,7 +175,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
direction |= (direction << 1) & mask;
direction |= (direction << 1) & mask;
if (((direction << 1) & player) != 0) {
if (((direction << 1) & player & notAFile) != 0) {
flips |= direction;
}
@@ -189,7 +189,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
direction |= (direction >>> 1) & mask;
direction |= (direction >>> 1) & mask;
if (((direction >>> 1) & player) != 0) {
if (((direction >>> 1) & player & notHFile) != 0) {
flips |= direction;
}
@@ -203,7 +203,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
direction |= (direction << 9) & mask;
direction |= (direction << 9) & mask;
if (((direction << 9) & player) != 0) {
if (((direction << 9) & player & notAFile) != 0) {
flips |= direction;
}
@@ -217,7 +217,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
direction |= (direction << 7) & mask;
direction |= (direction << 7) & mask;
if (((direction << 7) & player) != 0) {
if (((direction << 7) & player & notHFile) != 0) {
flips |= direction;
}
@@ -231,7 +231,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
direction |= (direction >>> 7) & mask;
direction |= (direction >>> 7) & mask;
if (((direction >>> 7) & player) != 0) {
if (((direction >>> 7) & player & notAFile) != 0) {
flips |= direction;
}
@@ -245,7 +245,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
direction |= (direction >>> 9) & mask;
direction |= (direction >>> 9) & mask;
if (((direction >>> 9) & player) != 0) {
if (((direction >>> 9) & player & notHFile) != 0) {
flips |= direction;
}
@@ -280,16 +280,20 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
int winner = getWinner();
if (winner == -1) {
return new PlayResult(GameState.DRAW, -1);
state = new PlayResult(GameState.DRAW, -1);
return state;
}
return new PlayResult(GameState.WIN, winner);
state = new PlayResult(GameState.WIN, winner);
return state;
}
return new PlayResult(GameState.TURN_SKIPPED, getCurrentPlayerIndex());
state = new PlayResult(GameState.TURN_SKIPPED, getCurrentPlayerIndex());
return state;
}
return new PlayResult(GameState.NORMAL, getCurrentPlayerIndex());
state = new PlayResult(GameState.NORMAL, getCurrentPlayerIndex());
return state;
}
public Score getScore() {

View File

@@ -39,7 +39,8 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
public PlayResult play(long move) {
// Player loses if move is invalid
if ((move & getLegalMoves()) == 0 || Long.bitCount(move) != 1){
return new PlayResult(GameState.WIN, getNextPlayer());
state = new PlayResult(GameState.WIN, getNextPlayer());
return state;
}
// Move is legal, make move
@@ -50,7 +51,8 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
// Check if current player won
if (checkWin(playerBitboard)) {
return new PlayResult(GameState.WIN, getCurrentPlayerIndex());
state = new PlayResult(GameState.WIN, getCurrentPlayerIndex());
return state;
}
// Proceed to next turn
@@ -59,11 +61,13 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
// Check for early draw
if (getLegalMoves() == 0L || checkEarlyDraw()) {
return new PlayResult(GameState.DRAW, -1);
state = new PlayResult(GameState.DRAW, -1);
return state;
}
// Nothing weird happened, continue on as normal
return new PlayResult(GameState.NORMAL, -1);
state = new PlayResult(GameState.NORMAL, -1);
return state;
}
private boolean checkWin(long board) {

View File

@@ -41,15 +41,19 @@ public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
return expanded >= children.length;
}
public Node bestUCTChild(float explorationFactor) {
/**
 * Computes a UCT-style score for THIS node from its own {@code value} and {@code visits}.
 *
 * NOTE(review): the loop in bestUCTChild calls {@code calculateUCT()} without a receiver,
 * so every iteration scores the parent node instead of {@code children[i]} and all
 * children tie. The logarithm below also uses this node's own visit count, not the
 * parent's. Confirm whether this should take the parent's visits as a parameter and be
 * invoked on each child.
 *
 * @return exploitation plus exploration term for this node
 */
float calculateUCT() {
// Average value per visit; defined as 0 for a node that has never been visited.
float exploitation = visits <= 0? 0 : value / visits;
// NOTE(review): there is no zero-visit guard here; with visits == 0 this evaluates to NaN.
float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits));
return exploitation + exploration;
}
public Node bestUCTChild() {
int bestChildIndex = -1;
float bestScore = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) {
float exploitation = children[i].visits <= 0? 0 : children[i].value / children[i].visits;
float exploration = explorationFactor * (float)(Math.sqrt(Math.log(visits) / (children[i].visits + 0.001f)));
float score = exploitation + exploration;
final float score = calculateUCT();
if (score > bestScore) {
bestChildIndex = i;
@@ -109,14 +113,12 @@ public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
}
}
System.out.println("Visit count: " + root.visits);
return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves());
}
private Node selection(Node node) {
while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) {
node = node.bestUCTChild(1.41f);
node = node.bestUCTChild();
}
return node;

View File

@@ -0,0 +1,195 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
/**
 * Monte Carlo Tree Search player (second iteration).
 *
 * <p>Every call to {@link #getMove} builds a fresh search tree and repeats
 * selection, expansion, random simulation and backpropagation until the time
 * budget is used up, then answers with the move of the most visited root child.
 *
 * <p>Moves and legal-move sets are bit masks in a {@code long}; a single move is
 * a mask with exactly one bit set.
 *
 * <p>Assumes {@code getCurrentTurn()} and {@code getWinner()} both yield player
 * indices, with a negative winner meaning "no winner" (TODO confirm against
 * {@code TurnBasedGame}).
 *
 * @param <T> the concrete game type this player plays
 */
public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {

    /** One position in the search tree together with its visit statistics. */
    private static class Node {
        /** Weight of the exploration term in the UCT formula. */
        private static final float EXPLORATION_CONSTANT = 1.41f;

        /** Game position represented by this node. */
        final TurnBasedGame<?> state;
        /** Move (single-bit mask) that led from the parent to this node; 0 for a fresh root. */
        final long move;
        /** Index of the player who played {@link #move}; -1 for a fresh root. */
        final int mover;
        /** Parent node, or {@code null} for the root. */
        final Node parent;
        /** Child slots, one per legal move; only the first {@link #expanded} are filled. */
        final Node[] children;
        /** Legal moves that have not been turned into a child yet. */
        long unexpandedMoves;
        /** Number of filled entries at the front of {@link #children}. */
        int expanded;
        /** Sum of rollout rewards, seen from the perspective of {@link #mover}. */
        float value;
        /** Number of rollouts that passed through this node. */
        int visits;

        Node(TurnBasedGame<?> state, Node parent, long move) {
            final long legalMoves = state.getLegalMoves();

            this.state = state;
            this.move = move;
            // The mover is whoever was to play in the parent position. Reading it from the
            // parent (instead of "1 - current turn" of this state) stays correct when a game
            // does not advance the turn on a winning move or when a turn is skipped.
            this.mover = parent != null ? parent.state.getCurrentTurn() : -1;
            this.parent = parent;
            this.children = new Node[Long.bitCount(legalMoves)];
            this.unexpandedMoves = legalMoves;
            this.expanded = 0;
            this.value = 0.0f;
            this.visits = 0;
        }

        Node(TurnBasedGame<?> state) {
            this(state, null, 0L);
        }

        /** Returns {@code true} once every legal move has a child node. */
        boolean isFullyExpanded() {
            return expanded == children.length;
        }

        /**
         * UCT score of this node as seen from its parent.
         *
         * @param parentVisits visit count of the parent node
         * @return average reward plus exploration bonus; positive infinity for an
         *         unvisited node so it is always tried first
         */
        float calculateUCT(int parentVisits) {
            if (visits <= 0) {
                // Avoids 0/0 (NaN), which would never win a ">" comparison.
                return Float.POSITIVE_INFINITY;
            }
            final float exploitation = value / visits;
            final float exploration =
                    EXPLORATION_CONSTANT * (float) Math.sqrt(Math.log(parentVisits) / visits);
            return exploitation + exploration;
        }

        /**
         * Returns the expanded child with the highest UCT score, or {@code null}
         * when this node has no expanded children.
         */
        Node bestUCTChild() {
            Node best = null;
            float bestScore = Float.NEGATIVE_INFINITY;
            for (int i = 0; i < expanded; i++) {
                final float score = children[i].calculateUCT(visits);
                if (score > bestScore) {
                    best = children[i];
                    bestScore = score;
                }
            }
            return best;
        }
    }

    private final Random random;
    private final int milliseconds;

    /**
     * @param milliseconds thinking time per move, in milliseconds
     */
    public MCTSAI2(int milliseconds) {
        this.random = new Random();
        this.milliseconds = milliseconds;
    }

    /**
     * Copy constructor. The copy shares the source's {@link Random} instance.
     *
     * @param other the player to copy the settings from
     */
    public MCTSAI2(MCTSAI2<?> other) {
        this.random = other.random;
        this.milliseconds = other.milliseconds;
    }

    @Override
    public MCTSAI2<T> deepCopy() {
        return new MCTSAI2<>(this);
    }

    /**
     * Searches from {@code game} for the configured time and picks a move.
     *
     * @param game the current position; it is not modified (children are built on copies)
     * @return the move of the most visited root child; when no child could be expanded
     *         in time, a random legal move (0 if there are no legal moves)
     */
    @Override
    public long getMove(T game) {
        final Node root = new Node(game, null, 0L);

        // Compare elapsed time, not absolute nanoTime values, so numeric overflow
        // of the nanosecond counter cannot break the loop condition.
        final long start = System.nanoTime();
        final long budget = milliseconds * 1_000_000L;

        while (System.nanoTime() - start < budget) {
            final Node leaf = expansion(selection(root));
            backPropagation(leaf, simulation(leaf));
        }

        final Node best = mostVisitedChild(root);
        // Falling back to a random legal move avoids answering with the invalid move 0
        // when the budget was too small for even a single iteration.
        return best != null ? best.move : randomSetBit(game.getLegalMoves());
    }

    /** Returns the expanded child of {@code root} with the most visits, or {@code null} if none. */
    private Node mostVisitedChild(Node root) {
        Node best = null;
        int bestVisits = -1;
        for (int i = 0; i < root.expanded; i++) {
            final Node child = root.children[i];
            if (child.visits > bestVisits) {
                best = child;
                bestVisits = child.visits;
            }
        }
        return best;
    }

    /**
     * Walks down the tree by UCT until reaching a node that is terminal, still has
     * unexpanded moves, or has no children at all.
     */
    private Node selection(Node root) {
        Node node = root;
        while (node.isFullyExpanded() && !node.state.isTerminal()) {
            final Node next = node.bestUCTChild();
            if (next == null) {
                // Non-terminal position without legal moves: nothing to descend into.
                break;
            }
            node = next;
        }
        return node;
    }

    /**
     * Adds one new child to {@code leaf} for its lowest unexpanded move.
     *
     * @return the new child, or {@code leaf} itself when it is terminal or fully expanded
     */
    private Node expansion(Node leaf) {
        // A finished game must not be played on, even if its legal-move mask is non-empty.
        if (leaf.unexpandedMoves == 0L || leaf.state.isTerminal()) {
            return leaf;
        }

        // Isolate the lowest set bit: the next move to expand.
        final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;

        final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
        copiedState.play(unexpandedMove);

        final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
        leaf.children[leaf.expanded] = expandedChild;
        leaf.expanded++;
        leaf.unexpandedMoves &= ~unexpandedMove;

        return expandedChild;
    }

    /**
     * Plays uniformly random legal moves on a copy of the leaf position until the game ends.
     *
     * @return the winner reported by the finished game (negative when nobody won)
     */
    private int simulation(Node leaf) {
        final TurnBasedGame<?> rollout = leaf.state.deepCopy();
        while (!rollout.isTerminal()) {
            rollout.play(randomSetBit(rollout.getLegalMoves()));
        }
        return rollout.getWinner();
    }

    /**
     * Records one rollout on every node from {@code leaf} up to the root.
     * Each node is rewarded from the perspective of the player who moved into it:
     * +1 if that player won, -1 if another player won, 0 if nobody won.
     */
    private void backPropagation(Node leaf, int winner) {
        for (Node node = leaf; node != null; node = node.parent) {
            node.visits++;
            if (winner >= 0) {
                node.value += (winner == node.mover) ? 1.0f : -1.0f;
            }
        }
    }

    /**
     * Picks one set bit of {@code value} uniformly at random.
     *
     * @return a mask with exactly one of the set bits of {@code value}, or 0 if none is set
     */
    private long randomSetBit(long value) {
        if (value == 0L) {
            return 0L;
        }

        // Clear a random number of low set bits, then isolate the lowest remaining one.
        final int skip = random.nextInt(Long.bitCount(value));
        for (int i = 0; i < skip; i++) {
            value &= value - 1;
        }
        return value & -value;
    }
}

View File

@@ -0,0 +1,258 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
/**
 * Monte Carlo Tree Search player (third iteration) that keeps its search tree
 * between moves.
 *
 * <p>Before searching, {@link #getMove} tries to find the current position at the
 * remembered root or one of its expanded children and continues from there; if the
 * position is not found the tree is rebuilt from scratch. After choosing a move the
 * root advances to the corresponding child.
 *
 * <p>Moves and legal-move sets are bit masks in a {@code long}; a single move is
 * a mask with exactly one bit set.
 *
 * <p>Assumes {@code getCurrentTurn()} and {@code getWinner()} both yield player
 * indices, with a negative winner meaning "no winner" (TODO confirm against
 * {@code TurnBasedGame}).
 *
 * @param <T> the concrete game type this player plays
 */
public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {

    /** One position in the search tree together with its visit statistics. */
    private static class Node {
        /** Weight of the exploration term in the UCT formula. */
        private static final float EXPLORATION_CONSTANT = 1.41f;

        /** Game position represented by this node. */
        final TurnBasedGame<?> state;
        /** Move (single-bit mask) that led from the parent to this node; 0 for a fresh root. */
        final long move;
        /** Index of the player who played {@link #move}; -1 for a fresh root. */
        final int mover;
        /** Child slots, one per legal move; only the first {@link #expanded} are filled. */
        final Node[] children;
        /** Parent node; {@code null} for the root (cleared when a node is promoted to root). */
        Node parent;
        /** Legal moves that have not been turned into a child yet. */
        long unexpandedMoves;
        /** Number of filled entries at the front of {@link #children}. */
        int expanded;
        /** Sum of rollout rewards, seen from the perspective of {@link #mover}. */
        float value;
        /** Number of rollouts that passed through this node. */
        int visits;

        Node(TurnBasedGame<?> state, Node parent, long move) {
            final long legalMoves = state.getLegalMoves();

            this.state = state;
            this.move = move;
            // The mover is whoever was to play in the parent position. Reading it from the
            // parent (instead of "1 - current turn" of this state) stays correct when a game
            // does not advance the turn on a winning move or when a turn is skipped.
            this.mover = parent != null ? parent.state.getCurrentTurn() : -1;
            this.parent = parent;
            this.children = new Node[Long.bitCount(legalMoves)];
            this.unexpandedMoves = legalMoves;
            this.expanded = 0;
            this.value = 0.0f;
            this.visits = 0;
        }

        Node(TurnBasedGame<?> state) {
            this(state, null, 0L);
        }

        /** Returns {@code true} once every legal move has a child node. */
        boolean isFullyExpanded() {
            return expanded == children.length;
        }

        /**
         * UCT score of this node as seen from its parent.
         *
         * @param parentVisits visit count of the parent node
         * @return average reward plus exploration bonus; positive infinity for an
         *         unvisited node so it is always tried first
         */
        float calculateUCT(int parentVisits) {
            if (visits <= 0) {
                // Avoids 0/0 (NaN), which would never win a ">" comparison.
                return Float.POSITIVE_INFINITY;
            }
            final float exploitation = value / visits;
            final float exploration =
                    EXPLORATION_CONSTANT * (float) Math.sqrt(Math.log(parentVisits) / visits);
            return exploitation + exploration;
        }

        /**
         * Returns the expanded child with the highest UCT score, or {@code null}
         * when this node has no expanded children.
         */
        Node bestUCTChild() {
            Node best = null;
            float bestScore = Float.NEGATIVE_INFINITY;
            for (int i = 0; i < expanded; i++) {
                final float score = children[i].calculateUCT(visits);
                if (score > bestScore) {
                    best = children[i];
                    bestScore = score;
                }
            }
            return best;
        }
    }

    private final Random random;
    /** Root of the tree kept between calls to {@link #getMove}; {@code null} before the first call. */
    private Node root;
    private final int milliseconds;

    /**
     * @param milliseconds thinking time per move, in milliseconds
     */
    public MCTSAI3(int milliseconds) {
        this.random = new Random();
        this.root = null;
        this.milliseconds = milliseconds;
    }

    /**
     * Copy constructor. The copy shares the source's {@link Random} instance and the
     * same tree nodes (the {@code root} reference is copied, not the tree).
     *
     * NOTE(review): two players searching through a shared tree would mutate the same
     * nodes; confirm that copies are never searched concurrently with their source.
     *
     * @param other the player to copy the settings from
     */
    public MCTSAI3(MCTSAI3<?> other) {
        this.random = other.random;
        this.root = other.root;
        this.milliseconds = other.milliseconds;
    }

    @Override
    public MCTSAI3<T> deepCopy() {
        return new MCTSAI3<>(this);
    }

    /**
     * Searches from {@code game} for the configured time, reusing the remembered tree
     * when the position is found in it, and picks a move.
     *
     * @param game the current position; it is not modified (the tree works on copies)
     * @return the move of the most visited root child; when no child could be expanded
     *         in time, a random legal move (0 if there are no legal moves)
     */
    @Override
    public long getMove(T game) {
        detectRoot(game);

        // Compare elapsed time, not absolute nanoTime values, so numeric overflow
        // of the nanosecond counter cannot break the loop condition.
        final long start = System.nanoTime();
        final long budget = milliseconds * 1_000_000L;

        while (System.nanoTime() - start < budget) {
            final Node leaf = expansion(selection(root));
            backPropagation(leaf, simulation(leaf));
        }

        final Node best = mostVisitedChild(root);
        // Falling back to a random legal move avoids answering with the invalid move 0
        // when the budget was too small for even a single iteration.
        final long move = best != null ? best.move : randomSetBit(game.getLegalMoves());

        newRoot(move);
        return move;
    }

    /** Returns the expanded child of {@code root} with the most visits, or {@code null} if none. */
    private Node mostVisitedChild(Node root) {
        Node best = null;
        int bestVisits = -1;
        for (int i = 0; i < root.expanded; i++) {
            final Node child = root.children[i];
            if (child.visits > bestVisits) {
                best = child;
                bestVisits = child.visits;
            }
        }
        return best;
    }

    /**
     * Points {@link #root} at the node for {@code game}: the current root if it already
     * matches, otherwise a matching expanded child, otherwise a brand-new tree.
     */
    private void detectRoot(T game) {
        if (root != null) {
            if (samePosition(root, game)) {
                return;
            }
            for (int i = 0; i < root.expanded; i++) {
                final Node child = root.children[i];
                if (samePosition(child, game)) {
                    adoptRoot(child);
                    return;
                }
            }
        }
        root = new Node(game.deepCopy());
    }

    /**
     * Returns {@code true} when {@code node} holds the same boards and the same side to
     * move as {@code game}. Boards of different length never match.
     */
    private static boolean samePosition(Node node, TurnBasedGame<?> game) {
        return node.state.getCurrentTurn() == game.getCurrentTurn()
                && java.util.Arrays.equals(node.state.getBoard(), game.getBoard());
    }

    /**
     * Makes {@code node} the new root and cuts it loose from its old parent so the
     * discarded part of the tree can be garbage collected and is no longer updated.
     */
    private void adoptRoot(Node node) {
        node.parent = null;
        root = node;
    }

    /**
     * Advances the root to the expanded child reached by {@code move}; leaves the root
     * unchanged when no expanded child matches.
     */
    private void newRoot(long move) {
        // Only the first `expanded` slots are filled; the rest of the array is null.
        for (int i = 0; i < root.expanded; i++) {
            final Node child = root.children[i];
            if (child.move == move) {
                adoptRoot(child);
                return;
            }
        }
    }

    /**
     * Walks down the tree by UCT until reaching a node that is terminal, still has
     * unexpanded moves, or has no children at all.
     */
    private Node selection(Node root) {
        Node node = root;
        while (node.isFullyExpanded() && !node.state.isTerminal()) {
            final Node next = node.bestUCTChild();
            if (next == null) {
                // Non-terminal position without legal moves: nothing to descend into.
                break;
            }
            node = next;
        }
        return node;
    }

    /**
     * Adds one new child to {@code leaf} for its lowest unexpanded move.
     *
     * @return the new child, or {@code leaf} itself when it is terminal or fully expanded
     */
    private Node expansion(Node leaf) {
        // A finished game must not be played on, even if its legal-move mask is non-empty.
        if (leaf.unexpandedMoves == 0L || leaf.state.isTerminal()) {
            return leaf;
        }

        // Isolate the lowest set bit: the next move to expand.
        final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;

        final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
        copiedState.play(unexpandedMove);

        final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
        leaf.children[leaf.expanded] = expandedChild;
        leaf.expanded++;
        leaf.unexpandedMoves &= ~unexpandedMove;

        return expandedChild;
    }

    /**
     * Plays uniformly random legal moves on a copy of the leaf position until the game ends.
     *
     * @return the winner reported by the finished game (negative when nobody won)
     */
    private int simulation(Node leaf) {
        final TurnBasedGame<?> rollout = leaf.state.deepCopy();
        while (!rollout.isTerminal()) {
            rollout.play(randomSetBit(rollout.getLegalMoves()));
        }
        return rollout.getWinner();
    }

    /**
     * Records one rollout on every node from {@code leaf} up to the root.
     * Each node is rewarded from the perspective of the player who moved into it:
     * +1 if that player won, -1 if another player won, 0 if nobody won.
     */
    private void backPropagation(Node leaf, int winner) {
        for (Node node = leaf; node != null; node = node.parent) {
            node.visits++;
            if (winner >= 0) {
                node.value += (winner == node.mover) ? 1.0f : -1.0f;
            }
        }
    }

    /**
     * Picks one set bit of {@code value} uniformly at random.
     *
     * @return a mask with exactly one of the set bits of {@code value}, or 0 if none is set
     */
    private long randomSetBit(long value) {
        if (value == 0L) {
            return 0L;
        }

        // Clear a random number of low set bits, then isolate the lowest remaining one.
        final int skip = random.nextInt(Long.bitCount(value));
        for (int i = 0; i < skip; i++) {
            value &= value - 1;
        }
        return value & -value;
    }
}