7 Commits

Author SHA1 Message Date
ramollia
3e4a343c4e update mcts 2026-01-16 15:55:25 +01:00
lieght
c4b9378128 Merge remote-tracking branch 'origin/223-create-a-reversi-ai-using-the-mcts-algorithm' into 223-create-a-reversi-ai-using-the-mcts-algorithm
# Conflicts:
#	app/src/main/java/org/toop/Main.java
#	app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java
#	framework/src/main/java/org/toop/framework/game/BitboardGame.java
#	framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java
#	game/src/main/java/org/toop/game/players/ArtificialPlayer.java
#	game/src/main/java/org/toop/game/players/ai/MCTSAI.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java
2026-01-16 14:46:24 +01:00
lieght
95e96583ec Merge changes on dev 2026-01-16 14:45:48 +01:00
ramollia
d02c7bd095 update mcts 2026-01-16 14:37:47 +01:00
ramollia
4e22c01bde mcts v1, v2, v3, v4 done. v5 wip 2026-01-16 14:33:47 +01:00
ramollia
c54b2a19e2 update mcts 2026-01-16 12:10:45 +01:00
ramollia
a6f5f2c854 mcts v1, v2, v3, v4 done. v5 wip 2026-01-15 01:37:33 +01:00
23 changed files with 212 additions and 1367 deletions

View File

@@ -1,9 +1,68 @@
package org.toop; package org.toop;
import org.toop.app.App; import org.toop.app.App;
import org.toop.framework.game.games.reversi.BitboardReversi;
import org.toop.framework.game.players.ArtificialPlayer;
import org.toop.game.players.ai.MCTSAI;
import org.toop.game.players.ai.RandomAI;
import org.toop.game.players.ai.mcts.MCTSAI1;
import org.toop.game.players.ai.mcts.MCTSAI2;
import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.mcts.MCTSAI4;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public final class Main { public final class Main {
static void main(String[] args) { static void main(String[] args) {
App.run(args); App.run(args);
// final ExecutorService executor = Executors.newFixedThreadPool(1);
// executor.execute(() -> testAIs(25));
} }
}
/**
 * Runs a round-robin benchmark: every AI version is matched once against
 * every version listed after it, with the earlier one seated as player 0.
 *
 * @param games number of games played per pairing
 */
private static void testAIs(int games) {
    final ArtificialPlayer[] contenders = {
        new ArtificialPlayer(new RandomAI(), "Random AI"),
        new ArtificialPlayer(new MCTSAI1(1000), "MCTS V1 AI"),
        new ArtificialPlayer(new MCTSAI2(1000), "MCTS V2 AI"),
        new ArtificialPlayer(new MCTSAI3(10, 10), "MCTS V3 AI"),
        new ArtificialPlayer(new MCTSAI4(10, 10), "MCTS V4 AI"),
    };

    for (int first = 0; first < contenders.length; first++) {
        for (int second = first + 1; second < contenders.length; second++) {
            // Each unordered pair is played exactly once.
            final ArtificialPlayer[] pairing = { contenders[first], contenders[second] };
            testAIVSAI(games, pairing);
        }
    }
}
/**
 * Plays a series of Reversi games between two AIs and prints the results
 * from the point of view of {@code ais[0]}.
 *
 * @param games number of games to play
 * @param ais   the two contenders; {@code ais[0]} is player 0, {@code ais[1]} is player 1
 */
private static void testAIVSAI(int games, ArtificialPlayer[] ais) {
    int firstPlayerWins = 0;
    int draws = 0;

    for (int played = 0; played < games; played++) {
        final BitboardReversi board = new BitboardReversi();
        board.init(ais);

        // Let whichever AI is on turn pick the move until the game ends.
        while (!board.isTerminal()) {
            final ArtificialPlayer onTurn = ais[board.getCurrentTurn()];
            board.play(onTurn.getMove(board));
        }

        final int winner = board.getWinner();
        if (winner < 0) {
            // A negative winner index is counted as a tie.
            draws++;
        } else if (winner == 0) {
            firstPlayerWins++;
        }
    }

    final int losses = games - firstPlayerWins - draws;
    System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), firstPlayerWins, draws, losses, ais[1].getName());
    System.out.printf("Average win rate was: %.2f\n\n", firstPlayerWins / (float)games);
}
}

View File

@@ -210,7 +210,7 @@ public final class Server {
Player[] players = new Player[2]; Player[] players = new Player[2];
players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000, 8), user); players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000, Runtime.getRuntime().availableProcessors()), user);
players[opponentStartingTurn] = new OnlinePlayer(response.opponent()); players[opponentStartingTurn] = new OnlinePlayer(response.opponent());
switch (type) { switch (type) {

View File

@@ -10,7 +10,6 @@ import org.toop.app.widget.WidgetContainer;
import org.toop.app.widget.view.GameView; import org.toop.app.widget.view.GameView;
import org.toop.framework.eventbus.EventFlow; import org.toop.framework.eventbus.EventFlow;
import org.toop.framework.eventbus.GlobalEventBus; import org.toop.framework.eventbus.GlobalEventBus;
import org.toop.framework.game.games.reversi.BitboardReversi;
import org.toop.framework.gameFramework.controller.GameController; import org.toop.framework.gameFramework.controller.GameController;
import org.toop.framework.gameFramework.model.game.threadBehaviour.SupportsOnlinePlay; import org.toop.framework.gameFramework.model.game.threadBehaviour.SupportsOnlinePlay;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
@@ -159,14 +158,8 @@ public class GenericGameController implements GameController {
canvas.redraw(gameCopy); canvas.redraw(gameCopy);
String gameType = game.getClass().getSimpleName().replace("Bitboard",""); String gameType = game.getClass().getSimpleName().replace("Bitboard","");
gameView.nextPlayer(true, getCurrentPlayer().getName(), game.getPlayer(1-getCurrentPlayerIndex()).getName(),gameType); gameView.nextPlayer(true, getCurrentPlayer().getName(), game.getPlayer(1-getCurrentPlayerIndex()).getName(),gameType);
if (gameType.equals("Reversi")) { if (getCurrentPlayer() instanceof LocalPlayer && gameType.equals("Reversi")){
BitboardReversi reversiGame = (BitboardReversi) game; ((ReversiBitCanvas)canvas).drawLegalDots(gameCopy);
BitboardReversi.Score reversiScore = reversiGame.getScore();
gameView.setPlayer1Score(reversiScore.black());
gameView.setPlayer2Score(reversiScore.white());
if (getCurrentPlayer() instanceof LocalPlayer) {
((ReversiBitCanvas)canvas).drawLegalDots(gameCopy);
}
} }
} }
} }

View File

@@ -11,19 +11,13 @@ import org.toop.framework.game.players.OnlinePlayer;
import java.util.Arrays; import java.util.Arrays;
public class ReversiBitController extends GenericGameController { public class ReversiBitController extends GenericGameController {
private BitboardReversi game;
public ReversiBitController(Player[] players) { public ReversiBitController(Player[] players) {
BitboardReversi game = new BitboardReversi(); BitboardReversi game = new BitboardReversi();
game.init(players); game.init(players);
ThreadBehaviour thread = Arrays.stream(players).anyMatch(e -> e instanceof OnlinePlayer) ? ThreadBehaviour thread = Arrays.stream(players).anyMatch(e -> e instanceof OnlinePlayer) ?
new OnlineThreadBehaviour(game) : new LocalThreadBehaviour(game); new OnlineThreadBehaviour(game) : new LocalThreadBehaviour(game);
super(new ReversiBitCanvas(), game, thread, "Reversi"); super(new ReversiBitCanvas(), game, thread, "Reversi");
} }
public BitboardReversi.Score getScore() {
return game.getScore();
}
} }

View File

@@ -4,7 +4,6 @@ import org.toop.app.widget.complex.ConfirmWidget;
import org.toop.app.widget.complex.PopupWidget; import org.toop.app.widget.complex.PopupWidget;
import javafx.geometry.Pos; import javafx.geometry.Pos;
import org.toop.framework.game.games.reversi.BitboardReversi;
public final class GameOverPopup extends PopupWidget { public final class GameOverPopup extends PopupWidget {
public GameOverPopup(boolean winOrTie, String winner) { public GameOverPopup(boolean winOrTie, String winner) {
@@ -16,6 +15,7 @@ public final class GameOverPopup extends PopupWidget {
else{ else{
confirmWidget.setMessage("It was a tie!"); confirmWidget.setMessage("It was a tie!");
} }
confirmWidget.addButton("ok", this::hide); confirmWidget.addButton("ok", this::hide);
add(Pos.CENTER, confirmWidget); add(Pos.CENTER, confirmWidget);

View File

@@ -26,8 +26,6 @@ public final class GameView extends ViewWidget {
private final Text player2Header; private final Text player2Header;
private Circle player1Icon; private Circle player1Icon;
private Circle player2Icon; private Circle player2Icon;
private final Text player1Score;
private final Text player2Score;
private final Button forfeitButton; private final Button forfeitButton;
private final Button exitButton; private final Button exitButton;
private final TextField chatInput; private final TextField chatInput;
@@ -42,8 +40,6 @@ public final class GameView extends ViewWidget {
player2Header = Primitive.header(""); player2Header = Primitive.header("");
player1Icon = new Circle(); player1Icon = new Circle();
player2Icon = new Circle(); player2Icon = new Circle();
player1Score = Primitive.header("");
player2Score = Primitive.header("");
if (onForfeit != null) { if (onForfeit != null) {
forfeitButton = Primitive.button("forfeit", () -> onForfeit.run(), false); forfeitButton = Primitive.button("forfeit", () -> onForfeit.run(), false);
@@ -157,16 +153,14 @@ public final class GameView extends ViewWidget {
private void setPlayerInfoReversi() { private void setPlayerInfoReversi() {
var player1box = Primitive.hbox( var player1box = Primitive.hbox(
player1Icon, player1Icon,
player1Header, player1Header
player1Score
); );
player1box.getStyleClass().add("hboxspacing"); player1box.getStyleClass().add("hboxspacing");
var player2box = Primitive.hbox( var player2box = Primitive.hbox(
player2Icon, player2Icon,
player2Header, player2Header
player2Score
); );
player2box.getStyleClass().add("hboxspacing"); player2box.getStyleClass().add("hboxspacing");
@@ -184,12 +178,4 @@ public final class GameView extends ViewWidget {
player2Icon.setFill(Color.BLACK); player2Icon.setFill(Color.BLACK);
add(Pos.TOP_RIGHT, playerInfo); add(Pos.TOP_RIGHT, playerInfo);
} }
public void setPlayer1Score(int score) {
player1Score.setText("(" + Integer.toString(score) + ")");
}
public void setPlayer2Score(int score) {
player2Score.setText("(" + Integer.toString(score) + ")");
}
} }

View File

@@ -82,13 +82,13 @@ public class LocalMultiplayerView extends ViewWidget {
if (information.players[0].isHuman) { if (information.players[0].isHuman) {
players[0] = new LocalPlayer(information.players[0].name); players[0] = new LocalPlayer(information.players[0].name);
} else { } else {
// players[0] = new ArtificialPlayer(new RandomAI(), "Random AI"); // players[0] = new ArtificialPlayer(new RandomAI<BitboardReversi>(), "Random AI");
players[0] = new ArtificialPlayer(new MCTSAI1(100), "MCTS V1 AI"); players[0] = new ArtificialPlayer(new MCTSAI4(500, 4), "MCTS V4 AI");
} }
if (information.players[1].isHuman) { if (information.players[1].isHuman) {
players[1] = new LocalPlayer(information.players[1].name); players[1] = new LocalPlayer(information.players[1].name);
} else { } else {
players[1] = new ArtificialPlayer(new MCTSAI4(100, 8), "MCTS V4 AI"); players[1] = new ArtificialPlayer(new MCTSAI1(500), "MCTS V1 AI");
} }
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
new ShowEnableTutorialWidget( new ShowEnableTutorialWidget(

View File

@@ -6,6 +6,7 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.Player; import org.toop.framework.gameFramework.model.player.Player;
import java.util.Arrays; import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
// There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy // There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy
public abstract class BitboardGame implements TurnBasedGame { public abstract class BitboardGame implements TurnBasedGame {
@@ -17,8 +18,8 @@ public abstract class BitboardGame implements TurnBasedGame {
private Player[] players; private Player[] players;
// long is 64 bits. Every game has a limit of 64 cells maximum. // long is 64 bits. Every game has a limit of 64 cells maximum.
protected final long[] playerBitboard; private final long[] playerBitboard;
protected int currentTurn = 0; private int currentTurn = 0;
private final int playerCount; private final int playerCount;
public BitboardGame(int columnSize, int rowSize, int playerCount) { public BitboardGame(int columnSize, int rowSize, int playerCount) {
@@ -74,8 +75,6 @@ public abstract class BitboardGame implements TurnBasedGame {
return playerBitboard.length; return playerBitboard.length;
} }
public int getAmountOfTurns() { return currentTurn; }
public int getCurrentTurn() { public int getCurrentTurn() {
return getCurrentPlayerIndex(); return getCurrentPlayerIndex();
} }

View File

@@ -4,6 +4,7 @@ import org.toop.framework.game.BitboardGame;
import org.toop.framework.gameFramework.GameState; import org.toop.framework.gameFramework.GameState;
import org.toop.framework.gameFramework.model.game.PlayResult; import org.toop.framework.gameFramework.model.game.PlayResult;
import org.toop.framework.gameFramework.model.player.Player; import org.toop.framework.gameFramework.model.player.Player;
import org.toop.framework.game.BitboardGame;
public class BitboardReversi extends BitboardGame { public class BitboardReversi extends BitboardGame {
@@ -367,11 +368,4 @@ public class BitboardReversi extends BitboardGame {
return bestMove; return bestMove;
} }
@Override
public void setFrom(long player1, long player2, int turn) {
this.playerBitboard[0] = player1;
this.playerBitboard[1] = player2;
this.currentTurn = turn;
}
} }

View File

@@ -2,7 +2,6 @@ package org.toop.framework.game.games.tictactoe;
import org.toop.framework.gameFramework.GameState; import org.toop.framework.gameFramework.GameState;
import org.toop.framework.gameFramework.model.game.PlayResult; import org.toop.framework.gameFramework.model.game.PlayResult;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.Player; import org.toop.framework.gameFramework.model.player.Player;
import org.toop.framework.game.BitboardGame; import org.toop.framework.game.BitboardGame;
@@ -121,8 +120,4 @@ public class BitboardTicTacToe extends BitboardGame {
public long heuristicMove(long legalMoves) { public long heuristicMove(long legalMoves) {
return legalMoves; return legalMoves;
} }
@Override
public void setFrom(long player1, long player2, int turn) {
}
} }

View File

@@ -14,9 +14,7 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
*/ */
public class ArtificialPlayer extends AbstractPlayer { public class ArtificialPlayer extends AbstractPlayer {
/** /** The AI instance used to calculate moves. */
* The AI instance used to calculate moves.
*/
private final AI ai; private final AI ai;
/** /**

View File

@@ -16,6 +16,4 @@ public interface TurnBasedGame extends DeepCopyable<TurnBasedGame> {
float rateMove(long move); float rateMove(long move);
long heuristicMove(long legalMoves); long heuristicMove(long legalMoves);
void setFrom(long player1, long player2, int turn);
} }

View File

@@ -4,12 +4,9 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI; import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random; import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class MCTSAI extends AbstractAI { public abstract class MCTSAI extends AbstractAI {
protected static class Node { protected static class Node {
public static final int VIRTUAL_LOSS = -1;
public TurnBasedGame state; public TurnBasedGame state;
public long move; public long move;
@@ -18,8 +15,8 @@ public abstract class MCTSAI extends AbstractAI {
public Node parent; public Node parent;
public Node[] children; public Node[] children;
public AtomicInteger value; public float value;
public AtomicInteger visits; public int visits;
public float heuristic; public float heuristic;
@@ -36,8 +33,8 @@ public abstract class MCTSAI extends AbstractAI {
this.parent = parent; this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)]; this.children = new Node[Long.bitCount(legalMoves)];
this.value = new AtomicInteger(0); this.value = 0.0f;
this.visits = new AtomicInteger(0); this.visits = 0;
this.heuristic = state.rateMove(move); this.heuristic = state.rateMove(move);
@@ -56,14 +53,14 @@ public abstract class MCTSAI extends AbstractAI {
return unexpandedMoves == 0L; return unexpandedMoves == 0L;
} }
public float calculateUCT(float explorationFactor) { public float calculateUCT(int parentVisits) {
if (visits.get() == 0) { if (visits == 0) {
return Float.POSITIVE_INFINITY; return Float.POSITIVE_INFINITY;
} }
final float exploitation = (float) value.get() / visits.get(); final float exploitation = value / visits;
final float exploration = (float)(Math.sqrt(explorationFactor / visits.get())); final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits));
final float bias = heuristic * 10.0f / (visits.get() + 1); final float bias = heuristic * 10.0f / (visits + 1);
return exploitation + exploration + bias; return exploitation + exploration + bias;
} }
@@ -75,7 +72,7 @@ public abstract class MCTSAI extends AbstractAI {
float highestUCT = Float.NEGATIVE_INFINITY; float highestUCT = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) { for (int i = 0; i < expanded; i++) {
final float childUCT = children[i].calculateUCT(2.0f * (float)Math.log(visits.get())); final float childUCT = children[i].calculateUCT(visits);
if (childUCT > highestUCT) { if (childUCT > highestUCT) {
highestUCTChild = children[i]; highestUCTChild = children[i];
@@ -110,40 +107,33 @@ public abstract class MCTSAI extends AbstractAI {
} }
protected Node selection(Node root) { protected Node selection(Node root) {
while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) { // while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) {
root.value.addAndGet(Node.VIRTUAL_LOSS); while (root.isFullyExpanded() && !root.state.isTerminal()) {
root.visits.incrementAndGet();
root = root.bestUCTChild(); root = root.bestUCTChild();
} }
root.value.addAndGet(Node.VIRTUAL_LOSS);
root.visits.incrementAndGet();
return root; return root;
} }
protected Node expansion(Node leaf) { protected Node expansion(Node leaf) {
synchronized (leaf) { if (leaf.unexpandedMoves == 0L) {
if (leaf.unexpandedMoves == 0L) { return leaf;
return leaf;
}
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.getExpanded()] = expandedChild;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
} }
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.getExpanded()] = expandedChild;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
} }
protected int simulation(Node leaf) { protected float simulation(Node leaf) {
final TurnBasedGame copiedState = leaf.state.deepCopy(); final TurnBasedGame copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn(); final int playerIndex = 1 - copiedState.getCurrentTurn();
@@ -155,20 +145,20 @@ public abstract class MCTSAI extends AbstractAI {
} }
if (copiedState.getWinner() == playerIndex) { if (copiedState.getWinner() == playerIndex) {
return 1; return 1.0f;
} }
if (copiedState.getWinner() >= 0) { if (copiedState.getWinner() >= 0) {
return -1; return -1.0f;
} }
return 0; return 0.0f;
} }
protected void backPropagation(Node leaf, int value) { protected void backPropagation(Node leaf, float value) {
while (leaf != null) { while (leaf != null) {
value -= Node.VIRTUAL_LOSS; leaf.value += value;
leaf.value.addAndGet(value); leaf.visits++;
if (Float.isNaN(leaf.solved)) { if (Float.isNaN(leaf.solved)) {
updateSolvedStatus(leaf); updateSolvedStatus(leaf);
@@ -186,9 +176,9 @@ public abstract class MCTSAI extends AbstractAI {
int mostVisited = -1; int mostVisited = -1;
for (int i = 0; i < expanded; i++) { for (int i = 0; i < expanded; i++) {
if (root.children[i].visits.get() > mostVisited) { if (root.children[i].visits > mostVisited) {
mostVisitedChild = root.children[i]; mostVisitedChild = root.children[i];
mostVisited = root.children[i].visits.get(); mostVisited = root.children[i].visits;
} }
} }

View File

@@ -1,195 +0,0 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
/**
 * Single-threaded Monte Carlo Tree Search AI.
 *
 * <p>Every {@link #getMove(TurnBasedGame)} call builds a fresh search tree and
 * repeats selection, expansion, random simulation and back-propagation until
 * the configured time budget has elapsed. The most visited child of the root
 * supplies the move.
 */
public class MCTSAI2 extends AbstractAI {
    /** One position in the search tree together with its search statistics. */
    private static class Node {
        /** Game position this node represents. */
        public TurnBasedGame state;
        /** Move that was played on the parent's state to reach this node ({@code 0L} for a root). */
        public long move;
        /** Bitmask of legal moves that do not have a child node yet. */
        public long unexpandedMoves;

        public Node parent;
        /** Sized to the number of legal moves; only the first {@link #expanded} slots are non-null. */
        public Node[] children;
        public int expanded;

        /** Sum of the simulation results back-propagated through this node. */
        public float value;
        public int visits;

        public Node(TurnBasedGame state, Node parent, long move) {
            final long legalMoves = state.getLegalMoves();

            this.state = state;
            this.move = move;
            this.unexpandedMoves = legalMoves;

            this.parent = parent;
            this.children = new Node[Long.bitCount(legalMoves)];
            this.expanded = 0;

            this.value = 0.0f;
            this.visits = 0;
        }

        public Node(TurnBasedGame state) {
            this(state, null, 0L);
        }

        /** @return {@code true} once every legal move has a child (also true when there are no legal moves) */
        public boolean isFullyExpanded() {
            return expanded == children.length;
        }

        /**
         * UCT score: average value plus an exploration bonus weighted by 1.41.
         * There is no guard for {@code visits == 0}; within this class every
         * expanded child is simulated and back-propagated before selection
         * can reach it again.
         *
         * @param parentVisits visit count of the parent node
         */
        public float calculateUCT(int parentVisits) {
            final float exploitation = value / visits;
            final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));

            return exploitation + exploration;
        }

        /** @return the expanded child with the highest UCT score, or {@code null} if none qualifies */
        public Node bestUCTChild() {
            Node highestUCTChild = null;
            float highestUCT = Float.NEGATIVE_INFINITY;

            for (int i = 0; i < expanded; i++) {
                final float childUCT = children[i].calculateUCT(visits);

                if (childUCT > highestUCT) {
                    highestUCTChild = children[i];
                    highestUCT = childUCT;
                }
            }

            return highestUCTChild;
        }
    }

    private final Random random;
    /** Search time budget per move, in milliseconds. */
    private final int milliseconds;

    /**
     * @param milliseconds how long each {@link #getMove(TurnBasedGame)} call may search
     */
    public MCTSAI2(int milliseconds) {
        this.random = new Random();
        this.milliseconds = milliseconds;
    }

    /**
     * Copy constructor. The {@link Random} instance is shared with
     * {@code other}, not duplicated.
     */
    public MCTSAI2(MCTSAI2 other) {
        this.random = other.random;
        this.milliseconds = other.milliseconds;
    }

    @Override
    public MCTSAI2 deepCopy() {
        return new MCTSAI2(this);
    }

    /**
     * Searches from {@code game} for the configured time budget.
     *
     * @param game current position; it is only read here, moves are played on deep copies
     * @return the move of the most visited root child, or {@code 0L} when the root has no expanded child
     */
    @Override
    public long getMove(TurnBasedGame game) {
        final Node root = new Node(game, null, 0L);

        final long endTime = System.nanoTime() + milliseconds * 1_000_000L;

        while (System.nanoTime() < endTime) {
            Node leaf = selection(root);
            leaf = expansion(leaf);
            final float value = simulation(leaf);
            backPropagation(leaf, value);
        }

        final Node mostVisitedChild = mostVisitedChild(root);
        return mostVisitedChild != null? mostVisitedChild.move : 0L;
    }

    /** @return the expanded child of {@code root} with the most visits, or {@code null} if there is none */
    private Node mostVisitedChild(Node root) {
        Node mostVisitedChild = null;
        int mostVisited = -1;

        for (int i = 0; i < root.expanded; i++) {
            if (root.children[i].visits > mostVisited) {
                mostVisitedChild = root.children[i];
                mostVisited = root.children[i].visits;
            }
        }

        return mostVisitedChild;
    }

    /**
     * Descends along the best-UCT children until reaching a node that still
     * has unexpanded moves, is terminal, or has no child to descend into.
     */
    private Node selection(Node root) {
        Node current = root;

        while (current.isFullyExpanded() && !current.state.isTerminal()) {
            final Node next = current.bestUCTChild();

            if (next == null) {
                // A non-terminal node without legal moves counts as fully
                // expanded but has no children; stop here instead of
                // dereferencing null on the next loop test.
                break;
            }

            current = next;
        }

        return current;
    }

    /**
     * Adds a child for the lowest-order unexpanded move of {@code leaf}.
     *
     * @return the new child, or {@code leaf} itself when nothing is left to expand
     */
    private Node expansion(Node leaf) {
        if (leaf.unexpandedMoves == 0L) {
            return leaf;
        }

        // Isolate the lowest set bit of the mask.
        final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;

        final TurnBasedGame copiedState = leaf.state.deepCopy();
        copiedState.play(unexpandedMove);

        final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);

        leaf.children[leaf.expanded] = expandedChild;
        leaf.expanded++;
        leaf.unexpandedMoves &= ~unexpandedMove;

        return expandedChild;
    }

    /**
     * Plays random moves on a copy of the leaf state until it is terminal.
     *
     * @return {@code 1.0f} if the winner is player {@code 1 - currentTurn} of the
     *         leaf state, {@code -1.0f} if any other player won, {@code 0.0f} otherwise
     */
    private float simulation(Node leaf) {
        final TurnBasedGame copiedState = leaf.state.deepCopy();
        // NOTE(review): presumably the player who moved into this node, assuming two players.
        final int playerIndex = 1 - copiedState.getCurrentTurn();

        while (!copiedState.isTerminal()) {
            final long legalMoves = copiedState.getLegalMoves();
            final long randomMove = randomSetBit(legalMoves);

            copiedState.play(randomMove);
        }

        if (copiedState.getWinner() == playerIndex) {
            return 1.0f;
        } else if (copiedState.getWinner() >= 0) {
            return -1.0f;
        }

        return 0.0f;
    }

    /** Adds {@code value} to every node from {@code leaf} up to the root, flipping its sign at each level. */
    private void backPropagation(Node leaf, float value) {
        while (leaf != null) {
            leaf.value += value;
            leaf.visits++;

            value = -value;
            leaf = leaf.parent;
        }
    }

    /**
     * @return a mask with one randomly chosen set bit of {@code value}, or {@code 0} when {@code value} is {@code 0}
     */
    private long randomSetBit(long value) {
        if (0L == value) {
            return 0;
        }

        final int bitCount = Long.bitCount(value);
        final int randomBitCount = random.nextInt(bitCount);

        // Clear that many of the lowest set bits, then isolate the next one.
        for (int i = 0; i < randomBitCount; i++) {
            value &= value - 1;
        }

        return value & -value;
    }
}

View File

@@ -1,258 +0,0 @@
package org.toop.game.players.ai;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
/**
 * Single-threaded Monte Carlo Tree Search AI that keeps its search tree
 * between moves.
 *
 * <p>After choosing a move the matching child becomes the new root. On the next
 * {@link #getMove(TurnBasedGame)} call the stored root (or one of its children)
 * is reused when its boards equal the boards of the given game; otherwise a
 * new tree is started.
 */
public class MCTSAI3 extends AbstractAI {
    /** One position in the search tree together with its search statistics. */
    private static class Node {
        /** Game position this node represents. */
        public TurnBasedGame state;
        /** Move that was played on the parent's state to reach this node ({@code 0L} for a fresh root). */
        public long move;
        /** Bitmask of legal moves that do not have a child node yet. */
        public long unexpandedMoves;

        public Node parent;
        /** Sized to the number of legal moves; only the first {@link #expanded} slots are non-null. */
        public Node[] children;
        public int expanded;

        /** Sum of the simulation results back-propagated through this node. */
        public float value;
        public int visits;

        public Node(TurnBasedGame state, Node parent, long move) {
            final long legalMoves = state.getLegalMoves();

            this.state = state;
            this.move = move;
            this.unexpandedMoves = legalMoves;

            this.parent = parent;
            this.children = new Node[Long.bitCount(legalMoves)];
            this.expanded = 0;

            this.value = 0.0f;
            this.visits = 0;
        }

        public Node(TurnBasedGame state) {
            this(state, null, 0L);
        }

        /** @return {@code true} once every legal move has a child (also true when there are no legal moves) */
        public boolean isFullyExpanded() {
            return expanded == children.length;
        }

        /**
         * UCT score: average value plus an exploration bonus weighted by 1.41.
         * There is no guard for {@code visits == 0}; within this class every
         * expanded child is simulated and back-propagated before selection
         * can reach it again.
         *
         * @param parentVisits visit count of the parent node
         */
        public float calculateUCT(int parentVisits) {
            final float exploitation = value / visits;
            final float exploration = 1.41f * (float)(Math.sqrt(Math.log(parentVisits) / visits));

            return exploitation + exploration;
        }

        /** @return the expanded child with the highest UCT score, or {@code null} if none qualifies */
        public Node bestUCTChild() {
            Node highestUCTChild = null;
            float highestUCT = Float.NEGATIVE_INFINITY;

            for (int i = 0; i < expanded; i++) {
                final float childUCT = children[i].calculateUCT(visits);

                if (childUCT > highestUCT) {
                    highestUCTChild = children[i];
                    highestUCT = childUCT;
                }
            }

            return highestUCTChild;
        }
    }

    private final Random random;

    /** Root of the tree kept between moves; {@code null} until the first search. */
    private Node root;

    /** Search time budget per move, in milliseconds. */
    private final int milliseconds;

    /**
     * @param milliseconds how long each {@link #getMove(TurnBasedGame)} call may search
     */
    public MCTSAI3(int milliseconds) {
        this.random = new Random();
        this.root = null;
        this.milliseconds = milliseconds;
    }

    /**
     * Copy constructor. Both the {@link Random} instance and the current
     * search tree are shared with {@code other}, not duplicated.
     */
    public MCTSAI3(MCTSAI3 other) {
        this.random = other.random;
        this.root = other.root;
        this.milliseconds = other.milliseconds;
    }

    @Override
    public MCTSAI3 deepCopy() {
        return new MCTSAI3(this);
    }

    /**
     * Searches from {@code game} for the configured time budget, reusing the
     * stored tree where possible, then re-roots the tree at the chosen move.
     *
     * @param game current position; searching happens on deep copies
     * @return the move of the most visited root child, or {@code 0L} when the root has no expanded child
     */
    @Override
    public long getMove(TurnBasedGame game) {
        detectRoot(game);

        final long endTime = System.nanoTime() + milliseconds * 1_000_000L;

        while (System.nanoTime() < endTime) {
            Node leaf = selection(root);
            leaf = expansion(leaf);
            final float value = simulation(leaf);
            backPropagation(leaf, value);
        }

        final Node mostVisitedChild = mostVisitedChild(root);
        final long move = mostVisitedChild != null? mostVisitedChild.move : 0L;

        newRoot(move);

        return move;
    }

    /** @return the expanded child of {@code root} with the most visits, or {@code null} if there is none */
    private Node mostVisitedChild(Node root) {
        Node mostVisitedChild = null;
        int mostVisited = -1;

        for (int i = 0; i < root.expanded; i++) {
            if (root.children[i].visits > mostVisited) {
                mostVisitedChild = root.children[i];
                mostVisited = root.children[i].visits;
            }
        }

        return mostVisitedChild;
    }

    /**
     * Points {@link #root} at the node matching {@code game}: the stored root
     * itself, one of its expanded children, or a brand-new tree when neither
     * matches. Only the boards are compared.
     */
    private void detectRoot(TurnBasedGame game) {
        if (root == null) {
            root = new Node(game.deepCopy());
            return;
        }

        final long[] currentBoards = game.getBoard();

        if (sameBoards(root.state.getBoard(), currentBoards)) {
            return;
        }

        for (int i = 0; i < root.expanded; i++) {
            final Node child = root.children[i];

            if (sameBoards(child.state.getBoard(), currentBoards)) {
                promote(child);
                return;
            }
        }

        root = new Node(game.deepCopy());
    }

    /** @return {@code true} if every board in {@code candidate} equals the board at the same index in {@code current} */
    private static boolean sameBoards(long[] candidate, long[] current) {
        for (int i = 0; i < candidate.length; i++) {
            if (candidate[i] != current[i]) {
                return false;
            }
        }

        return true;
    }

    /**
     * Re-roots the tree at the expanded child reached by {@code move}; leaves
     * the root unchanged when no expanded child has that move.
     */
    private void newRoot(long move) {
        // Only the first `expanded` slots hold nodes; the rest of the array is null.
        for (int i = 0; i < root.expanded; i++) {
            final Node child = root.children[i];

            if (child.move == move) {
                promote(child);
                return;
            }
        }
    }

    /**
     * Makes {@code node} the root and cuts its parent link, so that
     * back-propagation stops at the new root and the discarded part of the
     * tree is no longer referenced from it.
     */
    private void promote(Node node) {
        node.parent = null;
        root = node;
    }

    /**
     * Descends along the best-UCT children until reaching a node that still
     * has unexpanded moves, is terminal, or has no child to descend into.
     */
    private Node selection(Node root) {
        Node current = root;

        while (current.isFullyExpanded() && !current.state.isTerminal()) {
            final Node next = current.bestUCTChild();

            if (next == null) {
                // A non-terminal node without legal moves counts as fully
                // expanded but has no children; stop here instead of
                // dereferencing null on the next loop test.
                break;
            }

            current = next;
        }

        return current;
    }

    /**
     * Adds a child for the lowest-order unexpanded move of {@code leaf}.
     *
     * @return the new child, or {@code leaf} itself when nothing is left to expand
     */
    private Node expansion(Node leaf) {
        if (leaf.unexpandedMoves == 0L) {
            return leaf;
        }

        // Isolate the lowest set bit of the mask.
        final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;

        final TurnBasedGame copiedState = leaf.state.deepCopy();
        copiedState.play(unexpandedMove);

        final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);

        leaf.children[leaf.expanded] = expandedChild;
        leaf.expanded++;
        leaf.unexpandedMoves &= ~unexpandedMove;

        return expandedChild;
    }

    /**
     * Plays random moves on a copy of the leaf state until it is terminal.
     *
     * @return {@code 1.0f} if the winner is player {@code 1 - currentTurn} of the
     *         leaf state, {@code -1.0f} if any other player won, {@code 0.0f} otherwise
     */
    private float simulation(Node leaf) {
        final TurnBasedGame copiedState = leaf.state.deepCopy();
        // NOTE(review): presumably the player who moved into this node, assuming two players.
        final int playerIndex = 1 - copiedState.getCurrentTurn();

        while (!copiedState.isTerminal()) {
            final long legalMoves = copiedState.getLegalMoves();
            final long randomMove = randomSetBit(legalMoves);

            copiedState.play(randomMove);
        }

        if (copiedState.getWinner() == playerIndex) {
            return 1.0f;
        } else if (copiedState.getWinner() >= 0) {
            return -1.0f;
        }

        return 0.0f;
    }

    /** Adds {@code value} to every node from {@code leaf} up to the root, flipping its sign at each level. */
    private void backPropagation(Node leaf, float value) {
        while (leaf != null) {
            leaf.value += value;
            leaf.visits++;

            value = -value;
            leaf = leaf.parent;
        }
    }

    /**
     * @return a mask with one randomly chosen set bit of {@code value}, or {@code 0} when {@code value} is {@code 0}
     */
    private long randomSetBit(long value) {
        if (0L == value) {
            return 0;
        }

        final int bitCount = Long.bitCount(value);
        final int randomBitCount = random.nextInt(bitCount);

        // Clear that many of the lowest set bits, then isolate the next one.
        for (int i = 0; i < randomBitCount; i++) {
            value &= value - 1;
        }

        return value & -value;
    }
}

View File

@@ -23,14 +23,15 @@ public class MCTSAI1 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { // while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final int value = simulation(leaf); final float value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
lastIterations = root.visits.get(); lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move; return mostVisitedChild.move;

View File

@@ -29,14 +29,15 @@ public class MCTSAI2 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { // while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final int value = simulation(leaf); final float value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
lastIterations = root.visits.get(); lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move; final long move = mostVisitedChild.move;

View File

@@ -3,27 +3,26 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch; import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.Future;
public class MCTSAI3 extends MCTSAI { public class MCTSAI3 extends MCTSAI {
private final int threads; private final int threads;
private final ExecutorService threadPool;
public MCTSAI3(int milliseconds, int threads) { public MCTSAI3(int milliseconds, int threads) {
super(milliseconds); super(milliseconds);
this.threads = threads; this.threads = threads;
this.threadPool = Executors.newFixedThreadPool(threads);
} }
public MCTSAI3(MCTSAI3 other) { public MCTSAI3(MCTSAI3 other) {
super(other); super(other);
this.threads = other.threads; this.threads = other.threads;
this.threadPool = other.threadPool;
} }
@Override @Override
@@ -33,27 +32,53 @@ public class MCTSAI3 extends MCTSAI {
@Override @Override
public long getMove(TurnBasedGame game) { public long getMove(TurnBasedGame game) {
final Node root = new Node(game.deepCopy(), null, 0L); final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(threads); final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) { for (int i = 0; i < threads; i++) {
threadPool.submit(() -> { tasks.add(() -> {
try { final Node localRoot = new Node(game.deepCopy());
iterate(root, endTime);
} finally { // while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
latch.countDown(); while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
} }
return localRoot;
}); });
} }
try { try {
final long remaining = endTime - System.nanoTime(); final List<Future<Node>> results = pool.invokeAll(tasks);
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get(); pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move; return mostVisitedChild.move;
@@ -64,13 +89,4 @@ public class MCTSAI3 extends MCTSAI {
return randomSetBit(legalMoves); return randomSetBit(legalMoves);
} }
} }
private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
}
} }

View File

@@ -3,29 +3,29 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch; import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.Future;
public class MCTSAI4 extends MCTSAI { public class MCTSAI4 extends MCTSAI {
private final int threads; private final int threads;
private final ExecutorService threadPool; private final Node[] threadRoots;
private Node root;
public MCTSAI4(int milliseconds, int threads) { public MCTSAI4(int milliseconds, int threads) {
super(milliseconds); super(milliseconds);
this.threads = threads; this.threads = threads;
this.threadPool = Executors.newFixedThreadPool(threads); this.threadRoots = new Node[threads];
} }
public MCTSAI4(MCTSAI4 other) { public MCTSAI4(MCTSAI4 other) {
super(other); super(other);
this.threads = other.threads; this.threads = other.threads;
this.threadPool = other.threadPool; this.threadRoots = other.threadRoots;
} }
@Override @Override
@@ -35,32 +35,66 @@ public class MCTSAI4 extends MCTSAI {
@Override @Override
public long getMove(TurnBasedGame game) { public long getMove(TurnBasedGame game) {
root = findOrResetRoot(root, game); for (int i = 0; i < threads; i++) {
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(threads); final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) { for (int i = 0; i < threads; i++) {
threadPool.submit(() -> { final int threadIndex = i;
try {
iterate(root, endTime); tasks.add(() -> {
} finally { final Node localRoot = threadRoots[threadIndex];
latch.countDown();
// while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
} }
return localRoot;
}); });
} }
try { try {
final long remaining = endTime - System.nanoTime(); final List<Future<Node>> results = pool.invokeAll(tasks);
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get(); pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move; final long move = mostVisitedChild.move;
root = findChildByMove(root, move); for (int i = 0; i < threads; i++) {
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
return move; return move;
} catch (Exception _) { } catch (Exception _) {
@@ -70,13 +104,4 @@ public class MCTSAI4 extends MCTSAI {
return randomSetBit(legalMoves); return randomSetBit(legalMoves);
} }
} }
private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
}
} }

View File

@@ -1,77 +0,0 @@
package research;
/**
 * Mutable holder for the aggregated statistics of one AI across a series of games.
 *
 * <p>The fields are public and are additionally exposed through conventional
 * getters and setters; no validation is performed on any value.
 */
public class AIData {

    /** Display name identifying the AI these statistics belong to. */
    public String AI;

    /** Number of games that have been folded into these statistics. */
    public long gamesPlayed;

    /** Win rate of the AI. */
    public double winrate;

    /** Average iteration count over all moves. */
    public double averageIterations;

    /** Average iteration count for the first bucket of moves. */
    public double averageIterations10;

    /** Average iteration count for the second bucket of moves. */
    public double averageIterations20;

    /** Average iteration count for the third bucket of moves. */
    public double averageIterations30;

    /**
     * Creates a fully populated statistics entry.
     *
     * @param AI                  name of the AI
     * @param gamesPlayed         number of games played so far
     * @param winrate             current win rate
     * @param averageIterations   average iterations over all moves
     * @param averageIterations10 average iterations in the first bucket
     * @param averageIterations20 average iterations in the second bucket
     * @param averageIterations30 average iterations in the third bucket
     */
    public AIData(
            String AI,
            long gamesPlayed,
            double winrate,
            double averageIterations,
            double averageIterations10,
            double averageIterations20,
            double averageIterations30
    ) {
        this.AI = AI;
        this.gamesPlayed = gamesPlayed;
        this.winrate = winrate;
        this.averageIterations = averageIterations;
        this.averageIterations10 = averageIterations10;
        this.averageIterations20 = averageIterations20;
        this.averageIterations30 = averageIterations30;
    }

    /** @return the name of the AI */
    public String getAI() {
        return this.AI;
    }

    /** @param AI the new name of the AI */
    public void setAI(String AI) {
        this.AI = AI;
    }

    /** @return the number of games played */
    public long getGamesPlayed() {
        return this.gamesPlayed;
    }

    /** @param gamesPlayed the new number of games played */
    public void setGamesPlayed(long gamesPlayed) {
        this.gamesPlayed = gamesPlayed;
    }

    /** @return the win rate */
    public double getWinrate() {
        return this.winrate;
    }

    /** @param winrate the new win rate */
    public void setWinrate(double winrate) {
        this.winrate = winrate;
    }

    /** @return the average iterations over all moves */
    public double getAverageIterations() {
        return this.averageIterations;
    }

    /** @param averageIterations the new overall average */
    public void setAverageIterations(double averageIterations) {
        this.averageIterations = averageIterations;
    }

    /** @return the average iterations in the first bucket */
    public double getAverageIterations10() {
        return this.averageIterations10;
    }

    /** @param averageIterations10 the new first-bucket average */
    public void setAverageIterations10(double averageIterations10) {
        this.averageIterations10 = averageIterations10;
    }

    /** @return the average iterations in the second bucket */
    public double getAverageIterations20() {
        return this.averageIterations20;
    }

    /** @param averageIterations20 the new second-bucket average */
    public void setAverageIterations20(double averageIterations20) {
        this.averageIterations20 = averageIterations20;
    }

    /** @return the average iterations in the third bucket */
    public double getAverageIterations30() {
        return this.averageIterations30;
    }

    /** @param averageIterations30 the new third-bucket average */
    public void setAverageIterations30(double averageIterations30) {
        this.averageIterations30 = averageIterations30;
    }
}

View File

@@ -1,612 +0,0 @@
package research;
import org.apache.maven.surefire.shared.io.FileDeleteStrategy;
import org.junit.jupiter.api.*;
import org.toop.framework.game.games.reversi.BitboardReversi;
import org.toop.framework.game.players.ArtificialPlayer;
import org.toop.game.players.ai.MCTSAI;
import org.toop.game.players.ai.mcts.MCTSAI1;
import org.toop.game.players.ai.mcts.MCTSAI2;
import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.mcts.MCTSAI4;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.time.LocalTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
/**
 * Research harness that lets the MCTS AI versions play Reversi against each other
 * and appends one CSV row of statistics per finished game.
 *
 * <p>Per-move iteration counts are split into three buckets: the first 10 recorded
 * moves of a player, moves 11-20, and everything from move 21 onwards.
 */
public class AITest {
    /** CSV file that receives one row per finished game. */
    private static final String GAME_DATA_FILE = "gameData.csv";

    /** Header row of the per-game CSV file. */
    private static final String GAME_CSV_HEADER =
            "Black,White,Winner,Turns Played,Black total iterations,Black total iterations 0-10,Black total iterations 11-20,Black total iterations 21-30,Black average iterations,Black average iterations 0-10,Black average iterations 11-20,Black average iterations 21-30,White total iterations,White total iterations 0-10,White total iterations 11-20,White total iterations 21-30,White average iterations,White average iterations 0-10,White average iterations 11-20,White average iterations 21-30,Total Time AI1,Total Time AI2,Time";

    /** Exclusive end index of the first iteration bucket. */
    private static final int FIRST_BUCKET_END = 10;

    /** Exclusive end index of the second iteration bucket. */
    private static final int SECOND_BUCKET_END = 20;

    /** Formatter for the wall-clock timestamp stored with every game row. */
    private static final DateTimeFormatter TIME_FORMAT = DateTimeFormatter.ofPattern("HH:mm:ss");

    private static final List<Matchup> matchupList = new ArrayList<>();
    private static final List<AIData> dataList = new ArrayList<>();
    private static final List<GameData> gameDataList = new ArrayList<>();

    /**
     * Registers every pairing of the four MCTS versions twice, once with each
     * player moving first.
     */
    @BeforeAll
    public static void init() {
        final ArtificialPlayer[] versions = {
                new ArtificialPlayer(new MCTSAI1(10), "MCTS V1"),
                new ArtificialPlayer(new MCTSAI2(10), "MCTS V2"),
                new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3"),
                new ArtificialPlayer(new MCTSAI4(10, 8), "MCTS V4"),
        };

        for (int i = 0; i < versions.length; i++) {
            for (int j = i + 1; j < versions.length; j++) {
                addMatch(versions[i], versions[j]);
                addMatch(versions[j], versions[i]); // home vs away system
            }
        }
    }

    // Alternative setup comparing MCTSAI3 at different thread counts (kept for reference).
//    @BeforeAll
//    public static void init() {
//
//        var versions = new ArtificialPlayer[11];
//        versions[0] = new ArtificialPlayer(new MCTSAI3(10, 1), "MCTS V3T1");
//        versions[1] = new ArtificialPlayer(new MCTSAI3(10, 2), "MCTS V3T2");
//        versions[2] = new ArtificialPlayer(new MCTSAI3(10, 4), "MCTS V3T4");
//        versions[3] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3T8");
//        versions[4] = new ArtificialPlayer(new MCTSAI3(10, 16), "MCTS V3T16");
//        versions[5] = new ArtificialPlayer(new MCTSAI3(10, 128), "MCTS V3T32");
//        versions[6] = new ArtificialPlayer(new MCTSAI3(10, 256), "MCTS V3T64");
//        versions[7] = new ArtificialPlayer(new MCTSAI3(10, 128), "MCTS V3T128");
//        versions[8] = new ArtificialPlayer(new MCTSAI3(10, 256), "MCTS V3T256");
//        versions[9] = new ArtificialPlayer(new MCTSAI3(10, 512), "MCTS V3T512");
//        versions[10] = new ArtificialPlayer(new MCTSAI3(10, 1024), "MCTS V3T1024");
//
//        for (int i = 0; i < versions.length; i++) {
//            for (int j = i + 1; j < versions.length; j++) {
//                final int playerIndex1 = i % versions.length;
//                final int playerIndex2 = j % versions.length;
//                addMatch(versions[playerIndex1], versions[playerIndex2]);
//                addMatch(versions[playerIndex2], versions[playerIndex1]); // home vs away system
//            }
//        }
//    }

    /**
     * Adds a matchup in which {@code v1} is player 1 and {@code v2} is player 2.
     *
     * @param v1 the player registered first
     * @param v2 the player registered second
     */
    public static void addMatch(ArtificialPlayer v1, ArtificialPlayer v2) {
        matchupList.add(new Matchup(v1, v2));
    }

    /** @param data aggregated AI statistics to keep for the summary CSV */
    public void addAIData(AIData data) {
        dataList.add(data);
    }

    /** @param data per-game statistics to keep in memory */
    public void addGameData(GameData data) {
        gameDataList.add(data);
    }

    /**
     * Plays all registered matchups over and over.
     *
     * <p>The loop has no exit condition, so this test keeps running until the
     * process is stopped from the outside.
     */
    @Test
    public void testAIvsAI() {
        while (true) {
            for (Matchup m : matchupList) {
                playGame(m);
            }
        }
    }

    /**
     * Plays one full Reversi game for the given matchup and writes its statistics.
     *
     * <p>Iteration counts and thinking time are only collected for moves made by an
     * AI that is an {@link MCTSAI}.
     *
     * @param m the two players; player 1 is placed at index 0
     */
    public void playGame(Matchup m) {
        long nanocounterAI1 = 0L;
        long nanocounterAI2 = 0L;
        final List<Integer> iterationsAI1 = new ArrayList<>();
        final List<Integer> iterationsAI2 = new ArrayList<>();

        final BitboardReversi match = new BitboardReversi();
        final ArtificialPlayer[] players = {m.getPlayer1(), m.getPlayer2()};
        match.init(players);

        while (!match.isTerminal()) {
            final int currentAI = match.getCurrentTurn();
            final long startTime = System.nanoTime();
            final long move = players[currentAI].getMove(match);
            final long elapsed = System.nanoTime() - startTime;

            if (players[currentAI].getAi() instanceof MCTSAI) {
                final int lastIterations = ((MCTSAI) players[currentAI].getAi()).getLastIterations();
                if (currentAI == 0) {
                    iterationsAI1.add(lastIterations);
                    nanocounterAI1 += elapsed;
                } else {
                    iterationsAI2.add(lastIterations);
                    nanocounterAI2 += elapsed;
                }
            }
            match.play(move);
        }

        generateMatchData(m.getPlayer1().getName(), m.getPlayer2().getName(), match,
                iterationsAI1, iterationsAI2, nanocounterAI1, nanocounterAI2);
    }

    /**
     * Builds the {@link GameData} row for a finished game and appends it to the
     * per-game CSV file.
     *
     * <p>Games in which either player has fewer than 20 recorded moves are skipped,
     * because the first two buckets could not be filled.
     *
     * @param AI1            name of player 1
     * @param AI2            name of player 2
     * @param match          the finished game
     * @param iterationsAI1  iteration count of every recorded move of player 1
     * @param iterationsAI2  iteration count of every recorded move of player 2
     * @param nanocounterAI1 total thinking time of player 1 in nanoseconds
     * @param nanocounterAI2 total thinking time of player 2 in nanoseconds
     * @throws RuntimeException if the CSV file cannot be written
     */
    public void generateMatchData(
            String AI1,
            String AI2,
            BitboardReversi match,
            List<Integer> iterationsAI1,
            List<Integer> iterationsAI2,
            long nanocounterAI1,
            long nanocounterAI2
    ) {
        // Explicit length check instead of relying on subList throwing.
        if (iterationsAI1.size() < SECOND_BUCKET_END || iterationsAI2.size() < SECOND_BUCKET_END) {
            return;
        }

        final List<Integer> ai110 = iterationsAI1.subList(0, FIRST_BUCKET_END);
        final List<Integer> ai120 = iterationsAI1.subList(FIRST_BUCKET_END, SECOND_BUCKET_END);
        final List<Integer> ai130 = iterationsAI1.subList(SECOND_BUCKET_END, iterationsAI1.size());
        final List<Integer> ai210 = iterationsAI2.subList(0, FIRST_BUCKET_END);
        final List<Integer> ai220 = iterationsAI2.subList(FIRST_BUCKET_END, SECOND_BUCKET_END);
        final List<Integer> ai230 = iterationsAI2.subList(SECOND_BUCKET_END, iterationsAI2.size());

        final GameData gameData = new GameData(
                AI1,
                AI2,
                getWinnerForMatch(AI1, AI2, match),
                match.getAmountOfTurns(),
                total(iterationsAI1),
                total(ai110),
                total(ai120),
                total(ai130),
                mean(iterationsAI1),
                mean(ai110),
                mean(ai120),
                mean(ai130),
                total(iterationsAI2),
                total(ai210),
                total(ai220),
                total(ai230),
                mean(iterationsAI2),
                mean(ai210),
                mean(ai220),
                mean(ai230),
                nanocounterAI1,
                nanocounterAI2,
                LocalTime.now().format(TIME_FORMAT)
        );

        try {
            writeGamesToCSV(GAME_DATA_FILE, gameData);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Sums the values as {@code long} so large iteration totals cannot overflow an int. */
    private static long total(List<Integer> values) {
        return values.stream().mapToLong(Integer::longValue).sum();
    }

    /**
     * Arithmetic mean of the values; 0 for an empty list so that an empty bucket
     * never yields NaN (which {@link BigDecimal#valueOf(double)} rejects).
     */
    private static double mean(List<Integer> values) {
        if (values.isEmpty()) {
            return 0.0;
        }
        return values.stream().mapToDouble(Integer::doubleValue).sum() / values.size();
    }

    /**
     * Maps the winner index of a finished game to a player name.
     *
     * @param AI1   name returned when the winner index is 0
     * @param AI2   name returned when the winner index is 1
     * @param match the finished game
     * @return {@code AI1}, {@code AI2}, or {@code "TIE"} for any other winner value
     */
    public String getWinnerForMatch(String AI1, String AI2, BitboardReversi match) {
        final int winner = match.getWinner();
        if (winner == 0) {
            return AI1;
        }
        if (winner == 1) {
            return AI2;
        }
        return "TIE";
    }

    /**
     * Folds the result of one game into the aggregated statistics of both players,
     * creating a zeroed {@link AIData} entry for any player not seen before.
     *
     * @param matchup       the two players of the game
     * @param match         the finished game
     * @param iterationsAI1 iteration count of every recorded move of player 1
     * @param iterationsAI2 iteration count of every recorded move of player 2
     * @throws IndexOutOfBoundsException if either list has fewer than 20 entries
     */
    public void generateData(Matchup matchup, BitboardReversi match, List<Integer> iterationsAI1, List<Integer> iterationsAI2) {
        final String name1 = matchup.getPlayer1().getName();
        final String name2 = matchup.getPlayer2().getName();

        boolean player1Known = false;
        boolean player2Known = false;
        for (AIData aiData : dataList) {
            if (aiData.getAI().equals(name1)) {
                player1Known = true;
            }
            if (aiData.getAI().equals(name2)) {
                player2Known = true;
            }
        }
        if (!player1Known) {
            addAIData(new AIData(name1, 0, 0, 0, 0, 0, 0));
        }
        if (!player2Known) {
            addAIData(new AIData(name2, 0, 0, 0, 0, 0, 0));
        }

        final int winner = match.getWinner();
        recordGame(name1, 0, winner, iterationsAI1);
        recordGame(name2, 1, winner, iterationsAI2);
    }

    /** Updates every stored entry named {@code aiName} with the outcome of one game. */
    private void recordGame(String aiName, int playerIndex, int winner, List<Integer> iterations) {
        for (AIData aiData : dataList) {
            if (!aiData.getAI().equals(aiName)) {
                continue;
            }
            aiData.setGamesPlayed(aiData.getGamesPlayed() + 1);
            aiData.setWinrate(calculateWinrate(playerIndex, aiData.getWinrate(), aiData.getGamesPlayed(), winner));
            aiData.setAverageIterations(calculateAverageIterations(aiData.getAverageIterations(), iterations));
            aiData.setAverageIterations10(calculateAverageIterationsStartEnd(
                    0, FIRST_BUCKET_END, aiData.getAverageIterations10(), iterations));
            aiData.setAverageIterations20(calculateAverageIterationsStartEnd(
                    FIRST_BUCKET_END, SECOND_BUCKET_END, aiData.getAverageIterations20(), iterations));
            aiData.setAverageIterations30(calculateAverageIterationsStartEnd(
                    SECOND_BUCKET_END, iterations.size(), aiData.getAverageIterations30(), iterations));
        }
    }

    /**
     * Updates a running win rate with the outcome of one more game.
     *
     * @param player      index of the player (0 or 1) whose win rate is updated
     * @param winrate     win rate before this game
     * @param gamesPlayed number of games played including this one
     * @param winner      winner index of this game; a value other than 0 or 1 counts as no win
     * @return the win rate after this game
     */
    public double calculateWinrate(int player, double winrate, long gamesPlayed, int winner) {
        final boolean won = (winner == 0 || winner == 1) && winner == player;
        return (winrate * (gamesPlayed - 1) + (won ? 1 : 0)) / gamesPlayed;
    }

    /**
     * Blends the stored average with the mean iteration count of one game.
     *
     * <p>NOTE(review): the result is the midpoint of the stored value and this game's
     * mean, not a mean weighted by the number of games played.
     *
     * @param averageIterations  average stored so far
     * @param thisGameIterations iteration counts of this game
     * @return midpoint of {@code averageIterations} and this game's mean
     */
    public double calculateAverageIterations(double averageIterations, List<Integer> thisGameIterations) {
        double thisGameIterationsAverage = 0;
        for (final int iterations : thisGameIterations) {
            thisGameIterationsAverage += iterations;
        }
        thisGameIterationsAverage /= thisGameIterations.size();
        return (averageIterations + thisGameIterationsAverage) / 2;
    }

    /**
     * Same as {@link #calculateAverageIterations(double, List)} but restricted to the
     * moves in {@code [start, end)}.
     *
     * @param start              index of the first move to include
     * @param end                index one past the last move to include
     * @param averageIterations  average stored so far
     * @param thisGameIterations iteration counts of this game
     * @return midpoint of {@code averageIterations} and the mean of the selected moves
     * @throws IndexOutOfBoundsException if the range exceeds the list
     */
    public double calculateAverageIterationsStartEnd(int start, int end, double averageIterations, List<Integer> thisGameIterations) {
        double thisGameIterationsAverage = 0;
        for (int index = start; index < end; index++) {
            thisGameIterationsAverage += thisGameIterations.get(index);
        }
        thisGameIterationsAverage /= (end - start);
        return (averageIterations + thisGameIterationsAverage) / 2;
    }

    /** Writes the aggregated per-AI statistics to {@code Data.csv}; I/O failures are only printed. */
    @AfterAll
    public static void writeAfterTests() {
        try {
            writeAIToCsv("Data.csv", dataList);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Appends one game as a CSV row, writing the header first when the file is
     * missing, empty, or starts with a blank line.
     *
     * @param filepath path of the CSV file
     * @param gameData the game to append
     * @throws IOException if the file cannot be read or written
     */
    public static void writeGamesToCSV(String filepath, GameData gameData) throws IOException {
        // Decide on the header before the writer is opened, reading the first line once.
        final boolean headerNeeded = needsHeader(filepath);

        try (
                final BufferedWriter writer = Files.newBufferedWriter(
                        Paths.get(filepath),
                        StandardCharsets.UTF_8,
                        StandardOpenOption.CREATE,
                        StandardOpenOption.APPEND
                )
        ) {
            if (headerNeeded) {
                writer.write(GAME_CSV_HEADER);
                writer.newLine();
            }
            writer.write(String.join(",",
                    gameData.AI1(),
                    gameData.AI2(),
                    gameData.winner(),
                    String.valueOf(gameData.turns()),
                    String.valueOf(gameData.AI1totalIterations()),
                    String.valueOf(gameData.AI1totalIterations10()),
                    String.valueOf(gameData.AI1totalIterations20()),
                    String.valueOf(gameData.AI1totalIterations30()),
                    twoDecimals(gameData.AI1averageIterations()),
                    twoDecimals(gameData.AI1averageIterations10()),
                    twoDecimals(gameData.AI1averageIterations20()),
                    twoDecimals(gameData.AI1averageIterations30()),
                    String.valueOf(gameData.AI2totalIterations()),
                    String.valueOf(gameData.AI2totalIterations10()),
                    String.valueOf(gameData.AI2totalIterations20()),
                    String.valueOf(gameData.AI2totalIterations30()),
                    twoDecimals(gameData.AI2averageIterations()),
                    twoDecimals(gameData.AI2averageIterations10()),
                    twoDecimals(gameData.AI2averageIterations20()),
                    twoDecimals(gameData.AI2averageIterations30()),
                    String.valueOf(gameData.nanoAI1() / 1_000_000L), // nanoseconds -> milliseconds
                    String.valueOf(gameData.nanoAI2() / 1_000_000L),
                    gameData.time()));
            writer.newLine();
        }
    }

    /** Returns true when the file does not exist or its first line is absent or blank. */
    private static boolean needsHeader(String filepath) throws IOException {
        final var path = Paths.get(filepath);
        if (!Files.exists(path)) {
            return true;
        }
        try (BufferedReader reader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) {
            final String firstLine = reader.readLine();
            return firstLine == null || firstLine.isBlank();
        }
    }

    /** Formats a value with exactly two decimals using banker's rounding. */
    private static String twoDecimals(double value) {
        return BigDecimal.valueOf(value).setScale(2, RoundingMode.HALF_EVEN).toString();
    }

    /**
     * Writes the aggregated per-AI statistics as CSV, replacing any existing file.
     *
     * @param filepath path of the CSV file
     * @param dataList the entries to write, one row each
     * @throws IOException if the file cannot be written
     */
    public static void writeAIToCsv(String filepath, List<AIData> dataList) throws IOException {
        try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(filepath), StandardCharsets.UTF_8)) {
            writer.write("AI Name,Games Played,Winrate,Average Iterations,Average Iterations 0-10, Average Iterations 11-20, Average Iterations 20-30");
            writer.newLine();
            for (AIData data : dataList) {
                writer.write(
                        data.getAI() + "," +
                        data.getGamesPlayed() + "," +
                        data.getWinrate() + "," +
                        Math.round(data.getAverageIterations()) + "," +
                        Math.round(data.getAverageIterations10()) + "," +
                        Math.round(data.getAverageIterations20()) + "," +
                        Math.round(data.getAverageIterations30()));
                writer.newLine();
            }
        }
    }
}
//public class AITest {
// private static int games = 2;
//
// @BeforeAll
// public static void setUp() {
// var versions = new ArtificialPlayer[5];
// versions[0] = new ArtificialPlayer(new RandomAI(), "Random AI");
// versions[1] = new ArtificialPlayer(new MCTSAI1(20), "MCTS V1 AI");
// versions[2] = new ArtificialPlayer(new org.toop.game.players.ai.mcts.MCTSAI2(20), "MCTS V2 AI");
// versions[3] = new ArtificialPlayer(new org.toop.game.players.ai.mcts.MCTSAI3(20, 10), "MCTS V3 AI");
// versions[4] = new ArtificialPlayer(new MCTSAI4(20, 10), "MCTS V4 AI");
//
// for (int i = 0; i < versions.length; i++) {
// for (int j = i + 1; j < versions.length; j++) {
// final int playerIndex1 = i % versions.length;
// final int playerIndex2 = j % versions.length;
// addMatchup(versions[playerIndex1], versions[playerIndex2]);
// }
// }
//
// }
//
// @BeforeEach
// public void setUpEach() {
// matchupList = new ArrayList<>();
// }
//
// @Test
// public void testIterationsInRealGame() {
// for (int i = 0; i < matchups.size(); i++) {
// testAIVSAI(games, getMatchup(i));
// }
// }
//
//
// private void testAIVSAI(int games, ArtificialPlayer[] ais) {
//
// List<List<Integer>> gamesList = new ArrayList<>();
// for (int i = 0; i < games; i++) {
// final BitboardReversi match = new BitboardReversi();
// match.init(ais);
//
// List<Integer> iterations1 = new ArrayList<>();
// List<Integer> iterations2 = new ArrayList<>();
//
// while (!match.isTerminal()) {
// final int currentAI = match.getCurrentTurn();
// final long move = ais[currentAI].getMove(match);
// if (ais[currentAI].getAi() instanceof MCTSAI) {
// final int lastIterations = ((MCTSAI) ais[currentAI].getAi()).getLastIterations();
// if (currentAI == 0) {
// iterations1.add(lastIterations);
// } else if (currentAI == 1) {
// iterations2.add(lastIterations);
// }
// }
// match.play(move);
// }
// int winner = match.getWinner();
// iterations1.addFirst(winner);
//// iterations1.add(-999);
// iterations1.addAll(iterations2);
//
// gamesList.add(iterations1);
// }
// matchupList.add(gamesList);
// }
//
// @Test
// public void testIterationsAtFixedMove() {
// for (ArtificialPlayer[] matchup : matchups) {
// List<List<Integer>> gamesList = new ArrayList<>();
// for (int j = 0; j < games; j++) {
// final BitboardReversi match = new BitboardReversi();
// match.init(matchup);
//
// List<Integer> iterations = new ArrayList<>();
//
// for (Long move : fixedMoveSet) {
// match.play(move);
// if (move == 32L) {
// break;
// }
// }
//// iterations.add(-999);
// var player = matchup[match.getCurrentTurn()];
// for (int k = 0; k < 10; k++) {
// player.getMove(match);
// if (player.getAi() instanceof MCTSAI) {
// iterations.add(((MCTSAI) player.getAi()).getLastIterations());
// }
// }
// gamesList.add(iterations);
// }
// matchupList.add(gamesList);
// }
// }
//
//
// @Test
// public void testIterationsInFixedGame() {
// for (ArtificialPlayer[] matchup : matchups) {
// List<List<Integer>> gamesList = new ArrayList<>();
// for (int j = 0; j < games; j++) {
// final BitboardReversi match = new BitboardReversi();
// match.init(matchup);
//
// List<Integer> iterations = new ArrayList<>();
//
// iterations.add(-999);
//
// for (Long move : fixedMoveSet) {
// var player = matchup[match.getCurrentTurn()];
// player.getMove(match);
// if (player.getAi() instanceof MCTSAI) {
// iterations.add(((MCTSAI) player.getAi()).getLastIterations());
// }
// match.play(move);
// }
//
// gamesList.add(iterations);
// }
// matchupList.add(gamesList);
// }
// }
//
// @AfterEach
// public void tearDown() {
// data.add(matchupList);
// }
//
// @AfterAll
// public static void writeAfterTests() {
// try {
// writeToCsv("Data.csv", data);
// } catch (IOException e) {
//
// }
// }
//
//
// public static void writeToCsv(String filepath, List<List<List<List<Integer>>>> data) throws IOException {
// try (BufferedWriter writer = new BufferedWriter(new FileWriter(filepath))) {
//
// writer.write("TestID,Matchup,GameNr,Winner");
// for (int i = 0; i < data.size(); i++) {
// writer.write(",Iterations");
// }
//
// writer.newLine();
//
// for (int TestID = 0; TestID < data.size(); TestID++) {
// List<List<List<Integer>>> testCase = data.get(TestID);
//
// for (int matchupNr = 0; matchupNr < testCase.size(); matchupNr++) {
// List<List<Integer>> matchup = testCase.get(matchupNr);
//
// for (int gameNr = 0; gameNr < matchup.size(); gameNr++) {
// List<Integer> game = matchup.get(gameNr);
// writer.write((TestID + 1) + "," + (getMatchupName(matchupNr)) + "," + (gameNr + 1));
// for (int i = 0; i < game.size(); i++) {
// if (i == 0) {
// writer.write("," + getWinnerFromMatchup(game.get(i), matchupNr));
// } else {
// writer.write("," + game.get(i));
// }
// }
// writer.newLine();
// }
// }
// }
// }
//
// }
//
//
// private static final List<List<List<List<Integer>>>> data = new ArrayList<>();
// private List<List<List<Integer>>> matchupList = new ArrayList<>();
// private static final List<String> matchupNames = new ArrayList<>();
// private static final List<ArtificialPlayer[]> matchups = new ArrayList<>();
//
// private static String getMatchupName(int matchupNr) {
// return matchupNames.get(matchupNr);
// }
//
// private static ArtificialPlayer[] getMatchup(int matchupNr) {
// return matchups.get(matchupNr);
// }
//
// private static String getWinnerFromMatchup(Integer winner, int matchupNr) {
// String matchup = matchupNames.get(matchupNr);
//
// String[] parts = matchup.split(" vs ");
//
// if (parts.length != 2) {
// return "Invalid matchup formatting.";
// }
//
// return winner == 0 ? parts[0] : winner == 1 ? parts[1] : winner == -999 ? "NVT" : "Tie";
// }
//
// private static void addMatchup(ArtificialPlayer player1, ArtificialPlayer player2) {
// matchups.add(new ArtificialPlayer[]{player1, player2});
// matchupNames.add(player1.getName() + " vs " + player2.getName());
// }
//}
// private final Long[] fixedMoveSet = new Long[]{17592186044416L,
// 35184372088832L,
// 67108864L,
// 8796093022208L,
// 2251799813685248L,
// 288230376151711744L,
// 70368744177664L,
// 1125899906842624L,
// 137438953472L,
// 140737488355328L,
// 4503599627370496L,
// 2305843009213693952L,
// 18014398509481984L,
// 274877906944L,
// 576460752303423488L,
// -9223372036854775808L,
// 549755813888L,
// 1152921504606846976L,
// 144115188075855872L,
// 72057594037927936L,
// 36028797018963968L,
// 17179869184L,
// 2199023255552L,
// 1048576L,
// 4398046511104L,
// 281474976710656L,
// 9007199254740992L,
// 2147483648L,
// 1073741824L,
// 33554432L,
// 262144L,
// 8388608L,
// 8192L,
// 4611686018427387904L,
// 4294967296L,
// 524288L,
// 4096L,
// 16777216L,
// 65536L,
// 32L,
// 2048L,
// 8L,
// 4L,
// 8589934592L,
// 16L,
// 2097152L,
// 4194304L,
// 1024L,
// 512L,
// 16384L,
// 536870912L,
// 1099511627776L,
// 64L,
// 562949953421312L,
// 128L,
// 1L,
// 32768L,
// 2L,
// 256L,
// 131072L};
// }

View File

@@ -1,32 +0,0 @@
package research;
/**
 * Statistics of a single finished game between two AI players, written as one CSV row.
 *
 * <p>The descriptions below reflect how the research {@code AITest} harness fills the
 * record; the record itself performs no validation.
 *
 * @param AI1 name of player 1 (CSV column "Black")
 * @param AI2 name of player 2 (CSV column "White")
 * @param winner name of the winning AI, or {@code "TIE"}
 * @param turns number of turns reported by the game
 * @param AI1totalIterations sum of player 1's iterations over all recorded moves
 * @param AI1totalIterations10 sum of player 1's iterations over its first 10 recorded moves
 * @param AI1totalIterations20 sum of player 1's iterations over its recorded moves 11-20
 * @param AI1totalIterations30 sum of player 1's iterations from its recorded move 21 onwards
 * @param AI1averageIterations mean of player 1's iterations over all recorded moves
 * @param AI1averageIterations10 mean of player 1's iterations over its first 10 recorded moves
 * @param AI1averageIterations20 mean of player 1's iterations over its recorded moves 11-20
 * @param AI1averageIterations30 mean of player 1's iterations from its recorded move 21 onwards
 * @param AI2totalIterations sum of player 2's iterations over all recorded moves
 * @param AI2totalIterations10 sum of player 2's iterations over its first 10 recorded moves
 * @param AI2totalIterations20 sum of player 2's iterations over its recorded moves 11-20
 * @param AI2totalIterations30 sum of player 2's iterations from its recorded move 21 onwards
 * @param AI2averageIterations mean of player 2's iterations over all recorded moves
 * @param AI2averageIterations10 mean of player 2's iterations over its first 10 recorded moves
 * @param AI2averageIterations20 mean of player 2's iterations over its recorded moves 11-20
 * @param AI2averageIterations30 mean of player 2's iterations from its recorded move 21 onwards
 * @param nanoAI1 total time player 1 spent choosing moves, in nanoseconds
 * @param nanoAI2 total time player 2 spent choosing moves, in nanoseconds
 * @param time wall-clock time at which the row was created, formatted as {@code HH:mm:ss}
 */
public record GameData(
String AI1,
String AI2,
String winner,
int turns,
long AI1totalIterations,
long AI1totalIterations10,
long AI1totalIterations20,
long AI1totalIterations30,
double AI1averageIterations,
double AI1averageIterations10,
double AI1averageIterations20,
double AI1averageIterations30,
long AI2totalIterations,
long AI2totalIterations10,
long AI2totalIterations20,
long AI2totalIterations30,
double AI2averageIterations,
double AI2averageIterations10,
double AI2averageIterations20,
double AI2averageIterations30,
long nanoAI1,
long nanoAI2,
String time
) {}

View File

@@ -1,30 +0,0 @@
package research;
import org.toop.framework.game.players.ArtificialPlayer;
import java.util.ArrayList;
import java.util.List;
/**
 * Pairing of two AI players that are to play a game against each other.
 *
 * <p>Both players are public, mutable fields and may be {@code null} when the
 * instance was created with the no-argument constructor.
 */
public class Matchup {
    /** The first player of the pairing. */
    public ArtificialPlayer player1;

    /** The second player of the pairing. */
    public ArtificialPlayer player2;

    /**
     * Creates a matchup between the two given players.
     *
     * @param player1 the first player
     * @param player2 the second player
     */
    public Matchup(ArtificialPlayer player1, ArtificialPlayer player2) {
        this.player1 = player1;
        this.player2 = player2;
    }

    /** Creates an empty matchup; both players are {@code null} until assigned. */
    public Matchup() {}

    /**
     * Describes the pairing as {@code "<player1> VS <player2>"}.
     *
     * <p>String concatenation is used instead of calling {@code toString()} on the
     * fields directly, so an unassigned player is rendered as {@code "null"} rather
     * than causing a {@link NullPointerException}.
     *
     * @return the textual description of this matchup
     */
    @Override
    public String toString() {
        return player1 + " VS " + player2;
    }

    /** @return the first player, possibly {@code null} */
    public ArtificialPlayer getPlayer1() {
        return player1;
    }

    /** @return the second player, possibly {@code null} */
    public ArtificialPlayer getPlayer2() {
        return player2;
    }
}