mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 10:54:51 +00:00
bitboard fix & mcts v2 & mcts v3. v3 still in progress and v4 coming soon
This commit is contained in:
@@ -1,9 +1,61 @@
|
|||||||
package org.toop;
|
package org.toop;
|
||||||
|
|
||||||
import org.toop.app.App;
|
import org.toop.app.App;
|
||||||
|
import org.toop.framework.gameFramework.model.player.AbstractPlayer;
|
||||||
|
import org.toop.framework.gameFramework.model.player.Player;
|
||||||
|
import org.toop.game.games.reversi.BitboardReversi;
|
||||||
|
import org.toop.game.games.tictactoe.BitboardTicTacToe;
|
||||||
|
import org.toop.game.players.ArtificialPlayer;
|
||||||
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
|
import org.toop.game.players.ai.MCTSAI2;
|
||||||
|
import org.toop.game.players.ai.MCTSAI3;
|
||||||
|
import org.toop.game.players.ai.RandomAI;
|
||||||
|
|
||||||
public final class Main {
|
public final class Main {
|
||||||
static void main(String[] args) {
|
static void main(String[] args) {
|
||||||
App.run(args);
|
// App.run(args);
|
||||||
|
testMCTS(10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void testMCTS(int games) {
|
||||||
|
var random = new ArtificialPlayer<>(new RandomAI<BitboardReversi>(), "Random AI");
|
||||||
|
var v1 = new ArtificialPlayer<>(new MCTSAI<BitboardTicTacToe>(10), "MCTS V1 AI");
|
||||||
|
var v2 = new ArtificialPlayer<>(new MCTSAI2<BitboardTicTacToe>(10), "MCTS V2 AI");
|
||||||
|
var v2_2 = new ArtificialPlayer<>(new MCTSAI2<BitboardTicTacToe>(100), "MCTS V2_2 AI");
|
||||||
|
var v3 = new ArtificialPlayer<>(new MCTSAI3<BitboardTicTacToe>(10), "MCTS V3 AI");
|
||||||
|
|
||||||
|
testAI(games, new Player[]{ v1, v2 });
|
||||||
|
// testAI(games, new Player[]{ v1, v3 });
|
||||||
|
|
||||||
|
// testAI(games, new Player[]{ random, v3 });
|
||||||
|
// testAI(games, new Player[]{ v2, v3 });
|
||||||
|
testAI(games, new Player[]{ v2, v3 });
|
||||||
|
// testAI(games, new Player[]{ v3, v2 });
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void testAI(int games, Player<BitboardReversi>[] ais) {
|
||||||
|
int wins = 0;
|
||||||
|
int ties = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < games; i++) {
|
||||||
|
final BitboardReversi match = new BitboardReversi(ais);
|
||||||
|
|
||||||
|
while (!match.isTerminal()) {
|
||||||
|
final int currentAI = match.getCurrentTurn();
|
||||||
|
final long move = ais[currentAI].getMove(match);
|
||||||
|
|
||||||
|
match.play(move);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (match.getWinner() < 0) {
|
||||||
|
ties++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
wins += match.getWinner() == 0? 1 : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName());
|
||||||
|
System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import org.toop.app.widget.complex.ViewWidget;
|
|||||||
import org.toop.app.widget.popup.ErrorPopup;
|
import org.toop.app.widget.popup.ErrorPopup;
|
||||||
import org.toop.app.widget.tutorial.*;
|
import org.toop.app.widget.tutorial.*;
|
||||||
import org.toop.game.players.ai.MCTSAI;
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
|
import org.toop.game.players.ai.MCTSAI2;
|
||||||
|
import org.toop.game.players.ai.MCTSAI3;
|
||||||
import org.toop.game.players.ai.MiniMaxAI;
|
import org.toop.game.players.ai.MiniMaxAI;
|
||||||
import org.toop.game.players.ai.RandomAI;
|
import org.toop.game.players.ai.RandomAI;
|
||||||
import org.toop.local.AppContext;
|
import org.toop.local.AppContext;
|
||||||
@@ -55,7 +57,7 @@ public class LocalMultiplayerView extends ViewWidget {
|
|||||||
if (information.players[0].isHuman) {
|
if (information.players[0].isHuman) {
|
||||||
players[0] = new LocalPlayer<>(information.players[0].name);
|
players[0] = new LocalPlayer<>(information.players[0].name);
|
||||||
} else {
|
} else {
|
||||||
players[0] = new ArtificialPlayer<>(new RandomAI<BitboardTicTacToe>(), "Random AI");
|
players[0] = new ArtificialPlayer<>(new MCTSAI<BitboardTicTacToe>(100), "MCTS AI");
|
||||||
}
|
}
|
||||||
if (information.players[1].isHuman) {
|
if (information.players[1].isHuman) {
|
||||||
players[1] = new LocalPlayer<>(information.players[1].name);
|
players[1] = new LocalPlayer<>(information.players[1].name);
|
||||||
@@ -83,12 +85,13 @@ public class LocalMultiplayerView extends ViewWidget {
|
|||||||
if (information.players[0].isHuman) {
|
if (information.players[0].isHuman) {
|
||||||
players[0] = new LocalPlayer<>(information.players[0].name);
|
players[0] = new LocalPlayer<>(information.players[0].name);
|
||||||
} else {
|
} else {
|
||||||
players[0] = new ArtificialPlayer<>(new RandomAI<BitboardReversi>(), "Random AI");
|
// players[0] = new ArtificialPlayer<>(new RandomAI<BitboardReversi>(), "Random AI");
|
||||||
|
players[0] = new ArtificialPlayer<>(new MCTSAI3<BitboardReversi>(50), "MCTS V3 AI");
|
||||||
}
|
}
|
||||||
if (information.players[1].isHuman) {
|
if (information.players[1].isHuman) {
|
||||||
players[1] = new LocalPlayer<>(information.players[1].name);
|
players[1] = new LocalPlayer<>(information.players[1].name);
|
||||||
} else {
|
} else {
|
||||||
players[1] = new ArtificialPlayer<>(new MCTSAI<BitboardReversi>(1000), "MCTS AI");
|
players[1] = new ArtificialPlayer<>(new MCTSAI2<BitboardReversi>(50), "MCTS V2 AI");
|
||||||
}
|
}
|
||||||
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
|
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
|
||||||
new ShowEnableTutorialWidget(
|
new ShowEnableTutorialWidget(
|
||||||
|
|||||||
@@ -4,4 +4,7 @@ public interface TurnBasedGame<T extends TurnBasedGame<T>> extends Playable, Dee
|
|||||||
int getCurrentTurn();
|
int getCurrentTurn();
|
||||||
int getPlayerCount();
|
int getPlayerCount();
|
||||||
int getWinner();
|
int getWinner();
|
||||||
|
|
||||||
|
PlayResult getState();
|
||||||
|
boolean isTerminal();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
|||||||
*/
|
*/
|
||||||
public abstract class AbstractPlayer<T extends TurnBasedGame<T>> implements Player<T> {
|
public abstract class AbstractPlayer<T extends TurnBasedGame<T>> implements Player<T> {
|
||||||
|
|
||||||
private final Logger logger = LogManager.getLogger(this.getClass());
|
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package org.toop.game;
|
package org.toop.game;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.GameState;
|
||||||
|
import org.toop.framework.gameFramework.model.game.PlayResult;
|
||||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
import org.toop.framework.gameFramework.model.player.Player;
|
import org.toop.framework.gameFramework.model.player.Player;
|
||||||
|
|
||||||
@@ -11,6 +13,8 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
|||||||
private final int columnSize;
|
private final int columnSize;
|
||||||
private final int rowSize;
|
private final int rowSize;
|
||||||
|
|
||||||
|
protected PlayResult state;
|
||||||
|
|
||||||
private Player<T>[] players;
|
private Player<T>[] players;
|
||||||
|
|
||||||
// long is 64 bits. Every game has a limit of 64 cells maximum.
|
// long is 64 bits. Every game has a limit of 64 cells maximum.
|
||||||
@@ -20,6 +24,9 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
|||||||
public BitboardGame(int columnSize, int rowSize, int playerCount, Player<T>[] players) {
|
public BitboardGame(int columnSize, int rowSize, int playerCount, Player<T>[] players) {
|
||||||
this.columnSize = columnSize;
|
this.columnSize = columnSize;
|
||||||
this.rowSize = rowSize;
|
this.rowSize = rowSize;
|
||||||
|
|
||||||
|
this.state = new PlayResult(GameState.NORMAL, -1);
|
||||||
|
|
||||||
this.players = players;
|
this.players = players;
|
||||||
this.playerBitboard = new long[playerCount];
|
this.playerBitboard = new long[playerCount];
|
||||||
|
|
||||||
@@ -30,6 +37,8 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
|||||||
this.columnSize = other.columnSize;
|
this.columnSize = other.columnSize;
|
||||||
this.rowSize = other.rowSize;
|
this.rowSize = other.rowSize;
|
||||||
|
|
||||||
|
this.state = other.state;
|
||||||
|
|
||||||
this.playerBitboard = other.playerBitboard.clone();
|
this.playerBitboard = other.playerBitboard.clone();
|
||||||
this.currentTurn = other.currentTurn;
|
this.currentTurn = other.currentTurn;
|
||||||
this.players = Arrays.stream(other.players)
|
this.players = Arrays.stream(other.players)
|
||||||
@@ -61,7 +70,9 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
|||||||
return getCurrentPlayerIndex();
|
return getCurrentPlayerIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Player<T> getPlayer(int index) {return players[index];}
|
public Player<T> getPlayer(int index) {
|
||||||
|
return players[index];
|
||||||
|
}
|
||||||
|
|
||||||
public int getCurrentPlayerIndex() {
|
public int getCurrentPlayerIndex() {
|
||||||
return currentTurn % playerBitboard.length;
|
return currentTurn % playerBitboard.length;
|
||||||
@@ -75,9 +86,17 @@ public abstract class BitboardGame<T extends BitboardGame<T>> implements TurnBas
|
|||||||
return players[getCurrentPlayerIndex()];
|
return players[getCurrentPlayerIndex()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public PlayResult getState() {
|
||||||
|
return state;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isTerminal() {
|
||||||
|
return state.state() == GameState.WIN || state.state() == GameState.DRAW;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long[] getBoard() {return this.playerBitboard;}
|
public long[] getBoard() {return this.playerBitboard;}
|
||||||
|
|
||||||
public void nextTurn() {
|
public void nextTurn() {
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
direction |= (direction << 1) & mask;
|
direction |= (direction << 1) & mask;
|
||||||
direction |= (direction << 1) & mask;
|
direction |= (direction << 1) & mask;
|
||||||
|
|
||||||
if (((direction << 1) & player) != 0) {
|
if (((direction << 1) & player & notAFile) != 0) {
|
||||||
flips |= direction;
|
flips |= direction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
direction |= (direction >>> 1) & mask;
|
direction |= (direction >>> 1) & mask;
|
||||||
direction |= (direction >>> 1) & mask;
|
direction |= (direction >>> 1) & mask;
|
||||||
|
|
||||||
if (((direction >>> 1) & player) != 0) {
|
if (((direction >>> 1) & player & notHFile) != 0) {
|
||||||
flips |= direction;
|
flips |= direction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -203,7 +203,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
direction |= (direction << 9) & mask;
|
direction |= (direction << 9) & mask;
|
||||||
direction |= (direction << 9) & mask;
|
direction |= (direction << 9) & mask;
|
||||||
|
|
||||||
if (((direction << 9) & player) != 0) {
|
if (((direction << 9) & player & notAFile) != 0) {
|
||||||
flips |= direction;
|
flips |= direction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,7 +217,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
direction |= (direction << 7) & mask;
|
direction |= (direction << 7) & mask;
|
||||||
direction |= (direction << 7) & mask;
|
direction |= (direction << 7) & mask;
|
||||||
|
|
||||||
if (((direction << 7) & player) != 0) {
|
if (((direction << 7) & player & notHFile) != 0) {
|
||||||
flips |= direction;
|
flips |= direction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,7 +231,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
direction |= (direction >>> 7) & mask;
|
direction |= (direction >>> 7) & mask;
|
||||||
direction |= (direction >>> 7) & mask;
|
direction |= (direction >>> 7) & mask;
|
||||||
|
|
||||||
if (((direction >>> 7) & player) != 0) {
|
if (((direction >>> 7) & player & notAFile) != 0) {
|
||||||
flips |= direction;
|
flips |= direction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,7 +245,7 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
direction |= (direction >>> 9) & mask;
|
direction |= (direction >>> 9) & mask;
|
||||||
direction |= (direction >>> 9) & mask;
|
direction |= (direction >>> 9) & mask;
|
||||||
|
|
||||||
if (((direction >>> 9) & player) != 0) {
|
if (((direction >>> 9) & player & notHFile) != 0) {
|
||||||
flips |= direction;
|
flips |= direction;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,16 +280,20 @@ public class BitboardReversi extends BitboardGame<BitboardReversi> {
|
|||||||
int winner = getWinner();
|
int winner = getWinner();
|
||||||
|
|
||||||
if (winner == -1) {
|
if (winner == -1) {
|
||||||
return new PlayResult(GameState.DRAW, -1);
|
state = new PlayResult(GameState.DRAW, -1);
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new PlayResult(GameState.WIN, winner);
|
state = new PlayResult(GameState.WIN, winner);
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new PlayResult(GameState.TURN_SKIPPED, getCurrentPlayerIndex());
|
state = new PlayResult(GameState.TURN_SKIPPED, getCurrentPlayerIndex());
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new PlayResult(GameState.NORMAL, getCurrentPlayerIndex());
|
state = new PlayResult(GameState.NORMAL, getCurrentPlayerIndex());
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Score getScore() {
|
public Score getScore() {
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
|
|||||||
public PlayResult play(long move) {
|
public PlayResult play(long move) {
|
||||||
// Player loses if move is invalid
|
// Player loses if move is invalid
|
||||||
if ((move & getLegalMoves()) == 0 || Long.bitCount(move) != 1){
|
if ((move & getLegalMoves()) == 0 || Long.bitCount(move) != 1){
|
||||||
return new PlayResult(GameState.WIN, getNextPlayer());
|
state = new PlayResult(GameState.WIN, getNextPlayer());
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move is legal, make move
|
// Move is legal, make move
|
||||||
@@ -50,7 +51,8 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
|
|||||||
|
|
||||||
// Check if current player won
|
// Check if current player won
|
||||||
if (checkWin(playerBitboard)) {
|
if (checkWin(playerBitboard)) {
|
||||||
return new PlayResult(GameState.WIN, getCurrentPlayerIndex());
|
state = new PlayResult(GameState.WIN, getCurrentPlayerIndex());
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proceed to next turn
|
// Proceed to next turn
|
||||||
@@ -59,11 +61,13 @@ public class BitboardTicTacToe extends BitboardGame<BitboardTicTacToe> {
|
|||||||
|
|
||||||
// Check for early draw
|
// Check for early draw
|
||||||
if (getLegalMoves() == 0L || checkEarlyDraw()) {
|
if (getLegalMoves() == 0L || checkEarlyDraw()) {
|
||||||
return new PlayResult(GameState.DRAW, -1);
|
state = new PlayResult(GameState.DRAW, -1);
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nothing weird happened, continue on as normal
|
// Nothing weird happened, continue on as normal
|
||||||
return new PlayResult(GameState.NORMAL, -1);
|
state = new PlayResult(GameState.NORMAL, -1);
|
||||||
|
return state;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkWin(long board) {
|
private boolean checkWin(long board) {
|
||||||
|
|||||||
@@ -41,15 +41,19 @@ public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
|||||||
return expanded >= children.length;
|
return expanded >= children.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Node bestUCTChild(float explorationFactor) {
|
float calculateUCT() {
|
||||||
|
float exploitation = visits <= 0? 0 : value / visits;
|
||||||
|
float exploration = 1.41f * (float)(Math.sqrt(Math.log(visits) / visits));
|
||||||
|
|
||||||
|
return exploitation + exploration;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node bestUCTChild() {
|
||||||
int bestChildIndex = -1;
|
int bestChildIndex = -1;
|
||||||
float bestScore = Float.NEGATIVE_INFINITY;
|
float bestScore = Float.NEGATIVE_INFINITY;
|
||||||
|
|
||||||
for (int i = 0; i < expanded; i++) {
|
for (int i = 0; i < expanded; i++) {
|
||||||
float exploitation = children[i].visits <= 0? 0 : children[i].value / children[i].visits;
|
final float score = calculateUCT();
|
||||||
float exploration = explorationFactor * (float)(Math.sqrt(Math.log(visits) / (children[i].visits + 0.001f)));
|
|
||||||
|
|
||||||
float score = exploitation + exploration;
|
|
||||||
|
|
||||||
if (score > bestScore) {
|
if (score > bestScore) {
|
||||||
bestChildIndex = i;
|
bestChildIndex = i;
|
||||||
@@ -109,14 +113,12 @@ public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Visit count: " + root.visits);
|
|
||||||
|
|
||||||
return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves());
|
return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves());
|
||||||
}
|
}
|
||||||
|
|
||||||
private Node selection(Node node) {
|
private Node selection(Node node) {
|
||||||
while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) {
|
while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) {
|
||||||
node = node.bestUCTChild(1.41f);
|
node = node.bestUCTChild();
|
||||||
}
|
}
|
||||||
|
|
||||||
return node;
|
return node;
|
||||||
|
|||||||
195
game/src/main/java/org/toop/game/players/ai/MCTSAI2.java
Normal file
195
game/src/main/java/org/toop/game/players/ai/MCTSAI2.java
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package org.toop.game.players.ai;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
public class MCTSAI2<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||||
|
private static class Node {
|
||||||
|
public TurnBasedGame<?> state;
|
||||||
|
|
||||||
|
public long move;
|
||||||
|
public long unexpandedMoves;
|
||||||
|
|
||||||
|
public Node parent;
|
||||||
|
|
||||||
|
public Node[] children;
|
||||||
|
public int expanded;
|
||||||
|
|
||||||
|
public float value;
|
||||||
|
public int visits;
|
||||||
|
|
||||||
|
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||||
|
final long legalMoves = state.getLegalMoves();
|
||||||
|
|
||||||
|
this.state = state;
|
||||||
|
|
||||||
|
this.move = move;
|
||||||
|
this.unexpandedMoves = legalMoves;
|
||||||
|
|
||||||
|
this.parent = parent;
|
||||||
|
|
||||||
|
this.children = new Node[Long.bitCount(legalMoves)];
|
||||||
|
this.expanded = 0;
|
||||||
|
|
||||||
|
this.value = 0.0f;
|
||||||
|
this.visits = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node(TurnBasedGame<?> state) {
|
||||||
|
this(state, null, 0L);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isFullyExpanded() {
|
||||||
|
return expanded == children.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public float calculateUCT(int parentVisits) {
|
||||||
|
final float exploitation = value / visits;
|
||||||
|
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||||
|
|
||||||
|
return exploitation + exploration;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node bestUCTChild() {
|
||||||
|
Node highestUCTChild = null;
|
||||||
|
float highestUCT = Float.NEGATIVE_INFINITY;
|
||||||
|
|
||||||
|
for (int i = 0; i < expanded; i++) {
|
||||||
|
final float childUCT = children[i].calculateUCT(visits);
|
||||||
|
|
||||||
|
if (childUCT > highestUCT) {
|
||||||
|
highestUCTChild = children[i];
|
||||||
|
highestUCT = childUCT;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return highestUCTChild;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Random random;
|
||||||
|
private final int milliseconds;
|
||||||
|
|
||||||
|
public MCTSAI2(int milliseconds) {
|
||||||
|
this.random = new Random();
|
||||||
|
this.milliseconds = milliseconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI2(MCTSAI2<?> other) {
|
||||||
|
this.random = other.random;
|
||||||
|
this.milliseconds = other.milliseconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI2<T> deepCopy() {
|
||||||
|
return new MCTSAI2<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
final Node root = new Node(game, null, 0L);
|
||||||
|
|
||||||
|
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||||
|
|
||||||
|
while (System.nanoTime() < endTime) {
|
||||||
|
Node leaf = selection(root);
|
||||||
|
leaf = expansion(leaf);
|
||||||
|
final float value = simulation(leaf);
|
||||||
|
backPropagation(leaf, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
final Node mostVisitedChild = mostVisitedChild(root);
|
||||||
|
|
||||||
|
return mostVisitedChild != null? mostVisitedChild.move : 0L;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node mostVisitedChild(Node root) {
|
||||||
|
Node mostVisitedChild = null;
|
||||||
|
int mostVisited = -1;
|
||||||
|
|
||||||
|
for (int i = 0; i < root.expanded; i++) {
|
||||||
|
if (root.children[i].visits > mostVisited) {
|
||||||
|
mostVisitedChild = root.children[i];
|
||||||
|
mostVisited = root.children[i].visits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mostVisitedChild;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node selection(Node root) {
|
||||||
|
while (root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||||
|
root = root.bestUCTChild();
|
||||||
|
}
|
||||||
|
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node expansion(Node leaf) {
|
||||||
|
if (leaf.unexpandedMoves == 0L) {
|
||||||
|
return leaf;
|
||||||
|
}
|
||||||
|
|
||||||
|
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
||||||
|
|
||||||
|
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||||
|
copiedState.play(unexpandedMove);
|
||||||
|
|
||||||
|
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
||||||
|
|
||||||
|
leaf.children[leaf.expanded] = expandedChild;
|
||||||
|
leaf.expanded++;
|
||||||
|
|
||||||
|
leaf.unexpandedMoves &= ~unexpandedMove;
|
||||||
|
|
||||||
|
return expandedChild;
|
||||||
|
}
|
||||||
|
|
||||||
|
private float simulation(Node leaf) {
|
||||||
|
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||||
|
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
||||||
|
|
||||||
|
while (!copiedState.isTerminal()) {
|
||||||
|
final long legalMoves = copiedState.getLegalMoves();
|
||||||
|
final long randomMove = randomSetBit(legalMoves);
|
||||||
|
|
||||||
|
copiedState.play(randomMove);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (copiedState.getWinner() == playerIndex) {
|
||||||
|
return 1.0f;
|
||||||
|
} else if (copiedState.getWinner() >= 0) {
|
||||||
|
return -1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void backPropagation(Node leaf, float value) {
|
||||||
|
while (leaf != null) {
|
||||||
|
leaf.value += value;
|
||||||
|
leaf.visits++;
|
||||||
|
|
||||||
|
value = -value;
|
||||||
|
leaf = leaf.parent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private long randomSetBit(long value) {
|
||||||
|
if (0L == value) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
final int bitCount = Long.bitCount(value);
|
||||||
|
final int randomBitCount = random.nextInt(bitCount);
|
||||||
|
|
||||||
|
for (int i = 0; i < randomBitCount; i++) {
|
||||||
|
value &= value - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return value & -value;
|
||||||
|
}
|
||||||
|
}
|
||||||
258
game/src/main/java/org/toop/game/players/ai/MCTSAI3.java
Normal file
258
game/src/main/java/org/toop/game/players/ai/MCTSAI3.java
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
package org.toop.game.players.ai;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
public class MCTSAI3<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||||
|
private static class Node {
|
||||||
|
public TurnBasedGame<?> state;
|
||||||
|
|
||||||
|
public long move;
|
||||||
|
public long unexpandedMoves;
|
||||||
|
|
||||||
|
public Node parent;
|
||||||
|
|
||||||
|
public Node[] children;
|
||||||
|
public int expanded;
|
||||||
|
|
||||||
|
public float value;
|
||||||
|
public int visits;
|
||||||
|
|
||||||
|
public Node(TurnBasedGame<?> state, Node parent, long move) {
|
||||||
|
final long legalMoves = state.getLegalMoves();
|
||||||
|
|
||||||
|
this.state = state;
|
||||||
|
|
||||||
|
this.move = move;
|
||||||
|
this.unexpandedMoves = legalMoves;
|
||||||
|
|
||||||
|
this.parent = parent;
|
||||||
|
|
||||||
|
this.children = new Node[Long.bitCount(legalMoves)];
|
||||||
|
this.expanded = 0;
|
||||||
|
|
||||||
|
this.value = 0.0f;
|
||||||
|
this.visits = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node(TurnBasedGame<?> state) {
|
||||||
|
this(state, null, 0L);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isFullyExpanded() {
|
||||||
|
return expanded == children.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public float calculateUCT(int parentVisits) {
|
||||||
|
final float exploitation = value / visits;
|
||||||
|
final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
|
||||||
|
|
||||||
|
return exploitation + exploration;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node bestUCTChild() {
|
||||||
|
Node highestUCTChild = null;
|
||||||
|
float highestUCT = Float.NEGATIVE_INFINITY;
|
||||||
|
|
||||||
|
for (int i = 0; i < expanded; i++) {
|
||||||
|
final float childUCT = children[i].calculateUCT(visits);
|
||||||
|
|
||||||
|
if (childUCT > highestUCT) {
|
||||||
|
highestUCTChild = children[i];
|
||||||
|
highestUCT = childUCT;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return highestUCTChild;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Random random;
|
||||||
|
|
||||||
|
private Node root;
|
||||||
|
private final int milliseconds;
|
||||||
|
|
||||||
|
public MCTSAI3(int milliseconds) {
|
||||||
|
this.random = new Random();
|
||||||
|
|
||||||
|
this.root = null;
|
||||||
|
this.milliseconds = milliseconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI3(MCTSAI3<?> other) {
|
||||||
|
this.random = other.random;
|
||||||
|
|
||||||
|
this.root = other.root;
|
||||||
|
this.milliseconds = other.milliseconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI3<T> deepCopy() {
|
||||||
|
return new MCTSAI3<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
detectRoot(game);
|
||||||
|
|
||||||
|
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
|
||||||
|
|
||||||
|
while (System.nanoTime() < endTime) {
|
||||||
|
Node leaf = selection(root);
|
||||||
|
leaf = expansion(leaf);
|
||||||
|
final float value = simulation(leaf);
|
||||||
|
backPropagation(leaf, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
final Node mostVisitedChild = mostVisitedChild(root);
|
||||||
|
final long move = mostVisitedChild != null? mostVisitedChild.move : 0L;
|
||||||
|
|
||||||
|
newRoot(move);
|
||||||
|
|
||||||
|
return move;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node mostVisitedChild(Node root) {
|
||||||
|
Node mostVisitedChild = null;
|
||||||
|
int mostVisited = -1;
|
||||||
|
|
||||||
|
for (int i = 0; i < root.expanded; i++) {
|
||||||
|
if (root.children[i].visits > mostVisited) {
|
||||||
|
mostVisitedChild = root.children[i];
|
||||||
|
mostVisited = root.children[i].visits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mostVisitedChild;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void detectRoot(T game) {
|
||||||
|
if (root == null) {
|
||||||
|
root = new Node(game.deepCopy());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
final long[] currentBoards = game.getBoard();
|
||||||
|
final long[] rootBoards = root.state.getBoard();
|
||||||
|
|
||||||
|
boolean detected = true;
|
||||||
|
|
||||||
|
for (int i = 0; i < rootBoards.length; i++) {
|
||||||
|
if (rootBoards[i] != currentBoards[i]) {
|
||||||
|
detected = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (detected) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < root.expanded; i++) {
|
||||||
|
final Node child = root.children[i];
|
||||||
|
|
||||||
|
final long[] childBoards = child.state.getBoard();
|
||||||
|
|
||||||
|
detected = true;
|
||||||
|
|
||||||
|
for (int j = 0; j < childBoards.length; j++) {
|
||||||
|
if (childBoards[j] != currentBoards[j]) {
|
||||||
|
detected = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (detected) {
|
||||||
|
root = child;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
root = new Node(game.deepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void newRoot(long move) {
|
||||||
|
for (final Node child : root.children) {
|
||||||
|
if (child.move == move) {
|
||||||
|
root = child;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node selection(Node root) {
|
||||||
|
while (root.isFullyExpanded() && !root.state.isTerminal()) {
|
||||||
|
root = root.bestUCTChild();
|
||||||
|
}
|
||||||
|
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node expansion(Node leaf) {
|
||||||
|
if (leaf.unexpandedMoves == 0L) {
|
||||||
|
return leaf;
|
||||||
|
}
|
||||||
|
|
||||||
|
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
|
||||||
|
|
||||||
|
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||||
|
copiedState.play(unexpandedMove);
|
||||||
|
|
||||||
|
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
|
||||||
|
|
||||||
|
leaf.children[leaf.expanded] = expandedChild;
|
||||||
|
leaf.expanded++;
|
||||||
|
|
||||||
|
leaf.unexpandedMoves &= ~unexpandedMove;
|
||||||
|
|
||||||
|
return expandedChild;
|
||||||
|
}
|
||||||
|
|
||||||
|
private float simulation(Node leaf) {
|
||||||
|
final TurnBasedGame<?> copiedState = leaf.state.deepCopy();
|
||||||
|
final int playerIndex = 1 - copiedState.getCurrentTurn();
|
||||||
|
|
||||||
|
while (!copiedState.isTerminal()) {
|
||||||
|
final long legalMoves = copiedState.getLegalMoves();
|
||||||
|
final long randomMove = randomSetBit(legalMoves);
|
||||||
|
|
||||||
|
copiedState.play(randomMove);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (copiedState.getWinner() == playerIndex) {
|
||||||
|
return 1.0f;
|
||||||
|
} else if (copiedState.getWinner() >= 0) {
|
||||||
|
return -1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void backPropagation(Node leaf, float value) {
|
||||||
|
while (leaf != null) {
|
||||||
|
leaf.value += value;
|
||||||
|
leaf.visits++;
|
||||||
|
|
||||||
|
value = -value;
|
||||||
|
leaf = leaf.parent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private long randomSetBit(long value) {
|
||||||
|
if (0L == value) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
final int bitCount = Long.bitCount(value);
|
||||||
|
final int randomBitCount = random.nextInt(bitCount);
|
||||||
|
|
||||||
|
for (int i = 0; i < randomBitCount; i++) {
|
||||||
|
value &= value - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return value & -value;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user