7 Commits

Author SHA1 Message Date
ramollia
3e4a343c4e update mcts 2026-01-16 15:55:25 +01:00
lieght
c4b9378128 Merge remote-tracking branch 'origin/223-create-a-reversi-ai-using-the-mcts-algorithm' into 223-create-a-reversi-ai-using-the-mcts-algorithm
# Conflicts:
#	app/src/main/java/org/toop/Main.java
#	app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java
#	framework/src/main/java/org/toop/framework/game/BitboardGame.java
#	framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java
#	game/src/main/java/org/toop/game/players/ArtificialPlayer.java
#	game/src/main/java/org/toop/game/players/ai/MCTSAI.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java
2026-01-16 14:46:24 +01:00
lieght
95e96583ec Merge changes on dev 2026-01-16 14:45:48 +01:00
ramollia
d02c7bd095 update mcts 2026-01-16 14:37:47 +01:00
ramollia
4e22c01bde mcts v1, v2, v3, v4 done. v5 wip 2026-01-16 14:33:47 +01:00
ramollia
c54b2a19e2 update mcts 2026-01-16 12:10:45 +01:00
ramollia
a6f5f2c854 mcts v1, v2, v3, v4 done. v5 wip 2026-01-15 01:37:33 +01:00
21 changed files with 214 additions and 768 deletions

View File

@@ -20,4 +20,49 @@ public final class Main {
// final ExecutorService executor = Executors.newFixedThreadPool(1);
// executor.execute(() -> testAIs(25));
}
private static void testAIs(int games) {
var versions = new ArtificialPlayer[5];
versions[0] = new ArtificialPlayer(new RandomAI(), "Random AI");
versions[1] = new ArtificialPlayer(new MCTSAI1(1000), "MCTS V1 AI");
versions[2] = new ArtificialPlayer(new MCTSAI2(1000), "MCTS V2 AI");
versions[3] = new ArtificialPlayer(new MCTSAI3(10, 10), "MCTS V3 AI");
versions[4] = new ArtificialPlayer(new MCTSAI4(10, 10), "MCTS V4 AI");
for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length;
final int playerIndex2 = j % versions.length;
testAIVSAI(games, new ArtificialPlayer[] { versions[playerIndex1], versions[playerIndex2]});
}
}
}
private static void testAIVSAI(int games, ArtificialPlayer[] ais) {
int wins = 0;
int ties = 0;
for (int i = 0; i < games; i++) {
final BitboardReversi match = new BitboardReversi();
match.init(ais);
while (!match.isTerminal()) {
final int currentAI = match.getCurrentTurn();
final long move = ais[currentAI].getMove(match);
match.play(move);
}
if (match.getWinner() < 0) {
ties++;
continue;
}
wins += match.getWinner() == 0? 1 : 0;
}
System.out.printf("Out of %d games, %s won %d -- tied %d -- lost %d, games against %s\n", games, ais[0].getName(), wins, ties, games - wins - ties, ais[1].getName());
System.out.printf("Average win rate was: %.2f\n\n", wins / (float)games);
}
}

View File

@@ -210,7 +210,7 @@ public final class Server {
Player[] players = new Player[2];
players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000), user);
players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000, Runtime.getRuntime().availableProcessors()), user);
players[opponentStartingTurn] = new OnlinePlayer(response.opponent());
switch (type) {

View File

@@ -10,7 +10,6 @@ import org.toop.app.widget.WidgetContainer;
import org.toop.app.widget.view.GameView;
import org.toop.framework.eventbus.EventFlow;
import org.toop.framework.eventbus.GlobalEventBus;
import org.toop.framework.game.games.reversi.BitboardReversi;
import org.toop.framework.gameFramework.controller.GameController;
import org.toop.framework.gameFramework.model.game.threadBehaviour.SupportsOnlinePlay;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
@@ -159,14 +158,8 @@ public class GenericGameController implements GameController {
canvas.redraw(gameCopy);
String gameType = game.getClass().getSimpleName().replace("Bitboard","");
gameView.nextPlayer(true, getCurrentPlayer().getName(), game.getPlayer(1-getCurrentPlayerIndex()).getName(),gameType);
if (gameType.equals("Reversi")) {
BitboardReversi reversiGame = (BitboardReversi) game;
BitboardReversi.Score reversiScore = reversiGame.getScore();
gameView.setPlayer1Score(reversiScore.black());
gameView.setPlayer2Score(reversiScore.white());
if (getCurrentPlayer() instanceof LocalPlayer) {
if (getCurrentPlayer() instanceof LocalPlayer && gameType.equals("Reversi")){
((ReversiBitCanvas)canvas).drawLegalDots(gameCopy);
}
}
}
}

View File

@@ -11,19 +11,13 @@ import org.toop.framework.game.players.OnlinePlayer;
import java.util.Arrays;
public class ReversiBitController extends GenericGameController {
private BitboardReversi game;
public ReversiBitController(Player[] players) {
BitboardReversi game = new BitboardReversi();
game.init(players);
ThreadBehaviour thread = Arrays.stream(players).anyMatch(e -> e instanceof OnlinePlayer) ?
new OnlineThreadBehaviour(game) : new LocalThreadBehaviour(game);
super(new ReversiBitCanvas(), game, thread, "Reversi");
}
public BitboardReversi.Score getScore() {
return game.getScore();
}
}

View File

@@ -4,7 +4,6 @@ import org.toop.app.widget.complex.ConfirmWidget;
import org.toop.app.widget.complex.PopupWidget;
import javafx.geometry.Pos;
import org.toop.framework.game.games.reversi.BitboardReversi;
public final class GameOverPopup extends PopupWidget {
public GameOverPopup(boolean winOrTie, String winner) {
@@ -16,6 +15,7 @@ public final class GameOverPopup extends PopupWidget {
else{
confirmWidget.setMessage("It was a tie!");
}
confirmWidget.addButton("ok", this::hide);
add(Pos.CENTER, confirmWidget);

View File

@@ -26,8 +26,6 @@ public final class GameView extends ViewWidget {
private final Text player2Header;
private Circle player1Icon;
private Circle player2Icon;
private final Text player1Score;
private final Text player2Score;
private final Button forfeitButton;
private final Button exitButton;
private final TextField chatInput;
@@ -42,8 +40,6 @@ public final class GameView extends ViewWidget {
player2Header = Primitive.header("");
player1Icon = new Circle();
player2Icon = new Circle();
player1Score = Primitive.header("");
player2Score = Primitive.header("");
if (onForfeit != null) {
forfeitButton = Primitive.button("forfeit", () -> onForfeit.run(), false);
@@ -157,16 +153,14 @@ public final class GameView extends ViewWidget {
private void setPlayerInfoReversi() {
var player1box = Primitive.hbox(
player1Icon,
player1Header,
player1Score
player1Header
);
player1box.getStyleClass().add("hboxspacing");
var player2box = Primitive.hbox(
player2Icon,
player2Header,
player2Score
player2Header
);
player2box.getStyleClass().add("hboxspacing");
@@ -184,12 +178,4 @@ public final class GameView extends ViewWidget {
player2Icon.setFill(Color.BLACK);
add(Pos.TOP_RIGHT, playerInfo);
}
public void setPlayer1Score(int score) {
player1Score.setText("(" + Integer.toString(score) + ")");
}
public void setPlayer2Score(int score) {
player2Score.setText("(" + Integer.toString(score) + ")");
}
}

View File

@@ -82,13 +82,13 @@ public class LocalMultiplayerView extends ViewWidget {
if (information.players[0].isHuman) {
players[0] = new LocalPlayer(information.players[0].name);
} else {
// players[0] = new ArtificialPlayer(new RandomAI(), "Random AI");
players[0] = new ArtificialPlayer(new MCTSAI1(100), "MCTS V1 AI");
// players[0] = new ArtificialPlayer(new RandomAI<BitboardReversi>(), "Random AI");
players[0] = new ArtificialPlayer(new MCTSAI4(500, 4), "MCTS V4 AI");
}
if (information.players[1].isHuman) {
players[1] = new LocalPlayer(information.players[1].name);
} else {
players[1] = new ArtificialPlayer(new MCTSAI4(100), "MCTS V4 AI");
players[1] = new ArtificialPlayer(new MCTSAI1(500), "MCTS V1 AI");
}
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
new ShowEnableTutorialWidget(

View File

@@ -6,6 +6,7 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.Player;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
// There is AI performance to be gained by getting rid of non-primitives and thus speeding up deepCopy
public abstract class BitboardGame implements TurnBasedGame {
@@ -17,8 +18,8 @@ public abstract class BitboardGame implements TurnBasedGame {
private Player[] players;
// long is 64 bits. Every game has a limit of 64 cells maximum.
protected final long[] playerBitboard;
protected int currentTurn = 0;
private final long[] playerBitboard;
private int currentTurn = 0;
private final int playerCount;
public BitboardGame(int columnSize, int rowSize, int playerCount) {
@@ -74,8 +75,6 @@ public abstract class BitboardGame implements TurnBasedGame {
return playerBitboard.length;
}
public int getAmountOfTurns() { return currentTurn; }
public int getCurrentTurn() {
return getCurrentPlayerIndex();
}

View File

@@ -4,6 +4,7 @@ import org.toop.framework.game.BitboardGame;
import org.toop.framework.gameFramework.GameState;
import org.toop.framework.gameFramework.model.game.PlayResult;
import org.toop.framework.gameFramework.model.player.Player;
import org.toop.framework.game.BitboardGame;
public class BitboardReversi extends BitboardGame {
@@ -367,11 +368,4 @@ public class BitboardReversi extends BitboardGame {
return bestMove;
}
@Override
public void setFrom(long player1, long player2, int turn) {
this.playerBitboard[0] = player1;
this.playerBitboard[1] = player2;
this.currentTurn = turn;
}
}

View File

@@ -2,7 +2,6 @@ package org.toop.framework.game.games.tictactoe;
import org.toop.framework.gameFramework.GameState;
import org.toop.framework.gameFramework.model.game.PlayResult;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.Player;
import org.toop.framework.game.BitboardGame;
@@ -121,8 +120,4 @@ public class BitboardTicTacToe extends BitboardGame {
public long heuristicMove(long legalMoves) {
return legalMoves;
}
@Override
public void setFrom(long player1, long player2, int turn) {
}
}

View File

@@ -14,9 +14,7 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
*/
public class ArtificialPlayer extends AbstractPlayer {
/**
* The AI instance used to calculate moves.
*/
/** The AI instance used to calculate moves. */
private final AI ai;
/**

View File

@@ -16,6 +16,4 @@ public interface TurnBasedGame extends DeepCopyable<TurnBasedGame> {
float rateMove(long move);
long heuristicMove(long legalMoves);
void setFrom(long player1, long player2, int turn);
}

View File

@@ -4,12 +4,9 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class MCTSAI extends AbstractAI {
protected static class Node {
public static final int VIRTUAL_LOSS = -1;
public TurnBasedGame state;
public long move;
@@ -18,8 +15,8 @@ public abstract class MCTSAI extends AbstractAI {
public Node parent;
public Node[] children;
public AtomicInteger value;
public AtomicInteger visits;
public float value;
public int visits;
public float heuristic;
@@ -36,8 +33,8 @@ public abstract class MCTSAI extends AbstractAI {
this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)];
this.value = new AtomicInteger(0);
this.visits = new AtomicInteger(0);
this.value = 0.0f;
this.visits = 0;
this.heuristic = state.rateMove(move);
@@ -56,14 +53,14 @@ public abstract class MCTSAI extends AbstractAI {
return unexpandedMoves == 0L;
}
public float calculateUCT(float explorationFactor) {
if (visits.get() == 0) {
public float calculateUCT(int parentVisits) {
if (visits == 0) {
return Float.POSITIVE_INFINITY;
}
final float exploitation = (float) value.get() / visits.get();
final float exploration = (float)(Math.sqrt(explorationFactor / visits.get()));
final float bias = heuristic * 10.0f / (visits.get() + 1);
final float exploitation = value / visits;
final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits));
final float bias = heuristic * 10.0f / (visits + 1);
return exploitation + exploration + bias;
}
@@ -75,7 +72,7 @@ public abstract class MCTSAI extends AbstractAI {
float highestUCT = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) {
final float childUCT = children[i].calculateUCT(2.0f * (float)Math.log(visits.get()));
final float childUCT = children[i].calculateUCT(visits);
if (childUCT > highestUCT) {
highestUCTChild = children[i];
@@ -110,21 +107,15 @@ public abstract class MCTSAI extends AbstractAI {
}
protected Node selection(Node root) {
while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) {
root.value.addAndGet(Node.VIRTUAL_LOSS);
root.visits.incrementAndGet();
// while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) {
while (root.isFullyExpanded() && !root.state.isTerminal()) {
root = root.bestUCTChild();
}
root.value.addAndGet(Node.VIRTUAL_LOSS);
root.visits.incrementAndGet();
return root;
}
protected Node expansion(Node leaf) {
synchronized (leaf) {
if (leaf.unexpandedMoves == 0L) {
return leaf;
}
@@ -141,9 +132,8 @@ public abstract class MCTSAI extends AbstractAI {
return expandedChild;
}
}
protected int simulation(Node leaf) {
protected float simulation(Node leaf) {
final TurnBasedGame copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn();
@@ -155,20 +145,20 @@ public abstract class MCTSAI extends AbstractAI {
}
if (copiedState.getWinner() == playerIndex) {
return 1;
return 1.0f;
}
if (copiedState.getWinner() >= 0) {
return -1;
return -1.0f;
}
return 0;
return 0.0f;
}
protected void backPropagation(Node leaf, int value) {
protected void backPropagation(Node leaf, float value) {
while (leaf != null) {
value -= Node.VIRTUAL_LOSS;
leaf.value.addAndGet(value);
leaf.value += value;
leaf.visits++;
if (Float.isNaN(leaf.solved)) {
updateSolvedStatus(leaf);
@@ -186,9 +176,9 @@ public abstract class MCTSAI extends AbstractAI {
int mostVisited = -1;
for (int i = 0; i < expanded; i++) {
if (root.children[i].visits.get() > mostVisited) {
if (root.children[i].visits > mostVisited) {
mostVisitedChild = root.children[i];
mostVisited = root.children[i].visits.get();
mostVisited = root.children[i].visits;
}
}

View File

@@ -23,14 +23,15 @@ public class MCTSAI1 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
// while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
lastIterations = root.visits.get();
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move;

View File

@@ -29,14 +29,15 @@ public class MCTSAI2 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
// while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
lastIterations = root.visits.get();
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;

View File

@@ -3,21 +3,26 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Future;
public class MCTSAI3 extends MCTSAI {
private static final int THREADS = 8;
private final int threads;
private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS);
public MCTSAI3(int milliseconds) {
public MCTSAI3(int milliseconds, int threads) {
super(milliseconds);
this.threads = threads;
}
public MCTSAI3(MCTSAI3 other) {
super(other);
this.threads = other.threads;
}
@Override
@@ -27,18 +32,53 @@ public class MCTSAI3 extends MCTSAI {
@Override
public long getMove(TurnBasedGame game) {
final Node root = new Node(game.deepCopy(), null, 0L);
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
tasks.add(() -> {
final Node localRoot = new Node(game.deepCopy());
// while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS);
final List<Future<Node>> results = pool.invokeAll(tasks);
lastIterations = root.visits.get();
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move;
@@ -49,15 +89,4 @@ public class MCTSAI3 extends MCTSAI {
return randomSetBit(legalMoves);
}
}
private Void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
return null;
}
}

View File

@@ -3,27 +3,29 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Future;
public class MCTSAI4 extends MCTSAI {
private static final int THREADS = Runtime.getRuntime().availableProcessors();
private final int threads;
private final Node[] threadRoots;
private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS);
private Node root;
public MCTSAI4(int milliseconds) {
public MCTSAI4(int milliseconds, int threads) {
super(milliseconds);
this.root = null;
this.threads = threads;
this.threadRoots = new Node[threads];
}
public MCTSAI4(MCTSAI4 other) {
super(other);
this.root = other.root;
this.threads = other.threads;
this.threadRoots = other.threadRoots;
}
@Override
@@ -33,23 +35,66 @@ public class MCTSAI4 extends MCTSAI {
@Override
public long getMove(TurnBasedGame game) {
root = findOrResetRoot(root, game);
for (int i = 0; i < threads; i++) {
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
final int threadIndex = i;
tasks.add(() -> {
final Node localRoot = threadRoots[threadIndex];
// while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
while (System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
}
try {
threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS);
final List<Future<Node>> results = pool.invokeAll(tasks);
lastIterations = root.visits.get();
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;
root = findChildByMove(root, move);
for (int i = 0; i < threads; i++) {
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
return move;
} catch (Exception _) {
@@ -59,15 +104,4 @@ public class MCTSAI4 extends MCTSAI {
return randomSetBit(legalMoves);
}
}
private Void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
return null;
}
}

View File

@@ -1,77 +0,0 @@
package research;
public class AIData {
public String AI;
public long gamesPlayed;
public double winrate;
public double averageIterations;
public double averageIterations10;
public double averageIterations20;
public double averageIterations30;
public AIData(String AI, long gamesPlayed, double winrate, double averageIterations, double averageIterations10, double averageIterations20, double averageIterations30) {
this.AI = AI;
this.gamesPlayed = gamesPlayed;
this.winrate = winrate;
this.averageIterations = averageIterations;
this.averageIterations10 = averageIterations10;
this.averageIterations20 = averageIterations20;
this.averageIterations30 = averageIterations30;
}
public String getAI() {
return AI;
}
public void setAI(String AI) {
this.AI = AI;
}
public long getGamesPlayed() {
return gamesPlayed;
}
public void setGamesPlayed(long gamesPlayed) {
this.gamesPlayed = gamesPlayed;
}
public double getWinrate() {
return winrate;
}
public void setWinrate(double winrate) {
this.winrate = winrate;
}
public double getAverageIterations() {
return averageIterations;
}
public void setAverageIterations(double averageIterations) {
this.averageIterations = averageIterations;
}
public double getAverageIterations10() {
return averageIterations10;
}
public void setAverageIterations10(double averageIterations10) {
this.averageIterations10 = averageIterations10;
}
public double getAverageIterations20() {
return averageIterations20;
}
public void setAverageIterations20(double averageIterations20) {
this.averageIterations20 = averageIterations20;
}
public double getAverageIterations30() {
return averageIterations30;
}
public void setAverageIterations30(double averageIterations30) {
this.averageIterations30 = averageIterations30;
}
}

View File

@@ -1,500 +0,0 @@
package research;
import org.junit.jupiter.api.*;
import org.toop.framework.game.games.reversi.BitboardReversi;
import org.toop.framework.game.players.ArtificialPlayer;
import org.toop.game.players.ai.MCTSAI;
import org.toop.game.players.ai.mcts.MCTSAI1;
import org.toop.game.players.ai.mcts.MCTSAI2;
import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.mcts.MCTSAI4;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class AITest {
private static List<Matchup> matchupList = new ArrayList<Matchup>();
private static List<AIData> dataList = new ArrayList<AIData>();
private static List<GameData> gameDataList = new ArrayList<GameData>();
@BeforeAll
public static void init() {
var versions = new ArtificialPlayer[4];
versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
versions[2] = new ArtificialPlayer(new MCTSAI3(10), "MCTS V3");
versions[3] = new ArtificialPlayer(new MCTSAI4(10), "MCTS V4");
for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length;
final int playerIndex2 = j % versions.length;
addMatch(versions[playerIndex1], versions[playerIndex2]);
addMatch(versions[playerIndex2], versions[playerIndex1]); // home vs away system
}
}
}
public static void addMatch(ArtificialPlayer v1, ArtificialPlayer v2) {
matchupList.add(new Matchup(v1, v2));
}
public void addAIData(AIData data) {
dataList.add(data);
}
public void addGameData(GameData data) {
gameDataList.add(data);
}
@Test
public void testAIvsAI() {
for (Matchup m : matchupList) {
for (int i = 0; i < 6; i++) {
playGame(m);
}
}
}
public void playGame(Matchup m) {
long millisecondscounterAI1 = 0L;
long millisecondscounterAI2 = 0L;
List<Integer> iterationsAI1 = new ArrayList<>();
List<Integer> iterationsAI2 = new ArrayList<>();
final BitboardReversi match = new BitboardReversi();
ArtificialPlayer[] players = new ArtificialPlayer[2];
players[0] = m.getPlayer1();
players[1] = m.getPlayer2();
match.init(players);
while (!match.isTerminal()) {
final int currentAI = match.getCurrentTurn();
final long startTime = System.nanoTime();
final long move = players[currentAI].getMove(match);
final long endTime = System.nanoTime();
if (players[currentAI].getAi() instanceof MCTSAI) {
final int lastIterations = ((MCTSAI) players[currentAI].getAi()).getLastIterations();
if (currentAI == 0) {
iterationsAI1.add(lastIterations);
millisecondscounterAI1 += (endTime - startTime);
} else {
iterationsAI2.add(lastIterations);
millisecondscounterAI2 += (endTime - startTime);
}
}
match.play(move);
}
generateMatchData(m.getPlayer1().getName(), m.getPlayer2().getName(), match, millisecondscounterAI1, millisecondscounterAI2);
generateData(m, match, iterationsAI1, iterationsAI2);
}
public void generateMatchData(String AI1, String AI2, BitboardReversi match, long millisecondscounterAI1, long millisecondscounterAI2) {
addGameData(new GameData(
AI1,
AI2,
getWinnerForMatch(AI1, AI2, match),
match.getAmountOfTurns(),
millisecondscounterAI1,
millisecondscounterAI2
));
}
public String getWinnerForMatch(String AI1, String AI2, BitboardReversi match) {
if (match.getWinner() == 0) {
return AI1;
}
if (match.getWinner() == 1) {
return AI2;
} else {
return "TIE";
}
}
public void generateData(Matchup matchup, BitboardReversi match, List<Integer> iterationsAI1, List<Integer> iterationsAI2) {
boolean matchup1Found = false;
boolean matchup2Found = false;
for (AIData aiData : dataList) {
if (aiData.getAI().equals(matchup.getPlayer1().getName())) {
matchup1Found = true;
} if (aiData.getAI().equals(matchup.getPlayer2().getName())) {
matchup2Found = true;
}
}
if (!(matchup1Found)) {
addAIData(new AIData(matchup.getPlayer1().getName(), 0, 0, 0, 0, 0, 0));
}
if (!(matchup2Found)) {
addAIData(new AIData(matchup.getPlayer2().getName(), 0, 0, 0, 0, 0, 0));
}
for (AIData aiData : dataList) { // set data for player 1
if (aiData.getAI().equals(matchup.getPlayer1().getName())) {
aiData.setGamesPlayed(aiData.getGamesPlayed() + 1);
aiData.setWinrate(calculateWinrate(0, aiData.getWinrate(), aiData.getGamesPlayed(), match.getWinner()));
aiData.setAverageIterations(calculateAverageIterations(aiData.getAverageIterations(), iterationsAI1));
aiData.setAverageIterations10(calculateAverageIterationsStartEnd(0, 10, aiData.getAverageIterations10(), iterationsAI1));
aiData.setAverageIterations20(calculateAverageIterationsStartEnd(10, 20, aiData.getAverageIterations20(), iterationsAI1));
aiData.setAverageIterations30(calculateAverageIterationsStartEnd(20, iterationsAI1.size(), aiData.getAverageIterations30(), iterationsAI1));
}
}
for (AIData aiData : dataList) {
if (aiData.getAI().equals(matchup.getPlayer2().getName())) {
aiData.setGamesPlayed(aiData.getGamesPlayed() + 1);
aiData.setWinrate(calculateWinrate(1, aiData.getWinrate(), aiData.getGamesPlayed(), match.getWinner()));
aiData.setAverageIterations(calculateAverageIterations(aiData.getAverageIterations(), iterationsAI2));
aiData.setAverageIterations10(calculateAverageIterationsStartEnd(0, 10, aiData.getAverageIterations10(), iterationsAI2));
aiData.setAverageIterations20(calculateAverageIterationsStartEnd(10, 20, aiData.getAverageIterations20(), iterationsAI2));
aiData.setAverageIterations30(calculateAverageIterationsStartEnd(20, iterationsAI2.size(), aiData.getAverageIterations30(), iterationsAI2));
}
}
}
public double calculateWinrate(int player, double winrate, long gamesPlayed, int winner) {
double result;
if (winner == 0 && player == 0 || winner == 1 && player == 1) {
return (winrate * (gamesPlayed - 1) + 1) / gamesPlayed;
} else if (winner == 0 && player == 1 || winner == 1 && player == 0) {
return (winrate * (gamesPlayed - 1) + 0) / gamesPlayed;
}
return (winrate * (gamesPlayed - 1) + 0) / gamesPlayed;
}
public double calculateAverageIterations(double averageIterations, List<Integer> thisGameIterations) {
double thisGameIterationsAverage = 0;
for (int iterations = 0; iterations < thisGameIterations.size(); iterations += 1) {
thisGameIterationsAverage += thisGameIterations.get(iterations);
}
thisGameIterationsAverage /= thisGameIterations.size();
return (averageIterations + thisGameIterationsAverage) / 2;
}
public double calculateAverageIterationsStartEnd(int start, int end, double averageIterations, List<Integer> thisGameIterations) {
double thisGameIterationsAverage = 0;
for (int iterations = start; iterations < end; iterations += 1) {
thisGameIterationsAverage += thisGameIterations.get(iterations);
}
thisGameIterationsAverage /= (end - start);
return (averageIterations + thisGameIterationsAverage) / 2;
}
@AfterAll
public static void writeAfterTests() {
try {
writeAIToCsv("Data.csv", dataList);
writeGamesToCSV("Games.csv", gameDataList);
} catch (IOException e) {
e.printStackTrace();
}
}
public static void writeGamesToCSV(String filepath, List<GameData> gameDataList) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filepath))) {
writer.write("Black,White,Winner,Turns Played,Total Time AI1, Total Time AI2");
writer.newLine();
for (GameData gameData : gameDataList) {
writer.write(
gameData.AI1() + "," +
gameData.AI2() + "," +
gameData.winner() + "," +
gameData.turns() + "," +
(gameData.millisecondsAI1() / 1_000_000L) + "," +
(gameData.millisecondsAI2() / 1_000_000L));
writer.newLine();
}
}
}
public static void writeAIToCsv(String filepath, List<AIData> dataList) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filepath))) {
writer.write("AI Name,Games Played,Winrate,Average Iterations,Average Iterations 0-10, Average Iterations 11-20, Average Iterations 20-30");
writer.newLine();
for (AIData data : dataList) {
writer.write(
data.getAI() + "," +
data.getGamesPlayed() + "," +
data.getWinrate() + "," +
Math.round(data.getAverageIterations()) + "," +
Math.round(data.getAverageIterations10()) + "," +
Math.round(data.getAverageIterations20()) + "," +
Math.round(data.getAverageIterations30()));
writer.newLine();
}
}
}
}
//public class AITest {
// private static int games = 2;
//
// @BeforeAll
// public static void setUp() {
// var versions = new ArtificialPlayer[5];
// versions[0] = new ArtificialPlayer(new RandomAI(), "Random AI");
// versions[1] = new ArtificialPlayer(new MCTSAI1(20), "MCTS V1 AI");
// versions[2] = new ArtificialPlayer(new org.toop.game.players.ai.mcts.MCTSAI2(20), "MCTS V2 AI");
// versions[3] = new ArtificialPlayer(new org.toop.game.players.ai.mcts.MCTSAI3(20, 10), "MCTS V3 AI");
// versions[4] = new ArtificialPlayer(new MCTSAI4(20, 10), "MCTS V4 AI");
//
// for (int i = 0; i < versions.length; i++) {
// for (int j = i + 1; j < versions.length; j++) {
// final int playerIndex1 = i % versions.length;
// final int playerIndex2 = j % versions.length;
// addMatchup(versions[playerIndex1], versions[playerIndex2]);
// }
// }
//
// }
//
// @BeforeEach
// public void setUpEach() {
// matchupList = new ArrayList<>();
// }
//
// @Test
// public void testIterationsInRealGame() {
// for (int i = 0; i < matchups.size(); i++) {
// testAIVSAI(games, getMatchup(i));
// }
// }
//
//
// private void testAIVSAI(int games, ArtificialPlayer[] ais) {
//
// List<List<Integer>> gamesList = new ArrayList<>();
// for (int i = 0; i < games; i++) {
// final BitboardReversi match = new BitboardReversi();
// match.init(ais);
//
// List<Integer> iterations1 = new ArrayList<>();
// List<Integer> iterations2 = new ArrayList<>();
//
// while (!match.isTerminal()) {
// final int currentAI = match.getCurrentTurn();
// final long move = ais[currentAI].getMove(match);
// if (ais[currentAI].getAi() instanceof MCTSAI) {
// final int lastIterations = ((MCTSAI) ais[currentAI].getAi()).getLastIterations();
// if (currentAI == 0) {
// iterations1.add(lastIterations);
// } else if (currentAI == 1) {
// iterations2.add(lastIterations);
// }
// }
// match.play(move);
// }
// int winner = match.getWinner();
// iterations1.addFirst(winner);
//// iterations1.add(-999);
// iterations1.addAll(iterations2);
//
// gamesList.add(iterations1);
// }
// matchupList.add(gamesList);
// }
//
// @Test
// public void testIterationsAtFixedMove() {
// for (ArtificialPlayer[] matchup : matchups) {
// List<List<Integer>> gamesList = new ArrayList<>();
// for (int j = 0; j < games; j++) {
// final BitboardReversi match = new BitboardReversi();
// match.init(matchup);
//
// List<Integer> iterations = new ArrayList<>();
//
// for (Long move : fixedMoveSet) {
// match.play(move);
// if (move == 32L) {
// break;
// }
// }
//// iterations.add(-999);
// var player = matchup[match.getCurrentTurn()];
// for (int k = 0; k < 10; k++) {
// player.getMove(match);
// if (player.getAi() instanceof MCTSAI) {
// iterations.add(((MCTSAI) player.getAi()).getLastIterations());
// }
// }
// gamesList.add(iterations);
// }
// matchupList.add(gamesList);
// }
// }
//
//
// @Test
// public void testIterationsInFixedGame() {
// for (ArtificialPlayer[] matchup : matchups) {
// List<List<Integer>> gamesList = new ArrayList<>();
// for (int j = 0; j < games; j++) {
// final BitboardReversi match = new BitboardReversi();
// match.init(matchup);
//
// List<Integer> iterations = new ArrayList<>();
//
// iterations.add(-999);
//
// for (Long move : fixedMoveSet) {
// var player = matchup[match.getCurrentTurn()];
// player.getMove(match);
// if (player.getAi() instanceof MCTSAI) {
// iterations.add(((MCTSAI) player.getAi()).getLastIterations());
// }
// match.play(move);
// }
//
// gamesList.add(iterations);
// }
// matchupList.add(gamesList);
// }
// }
//
// @AfterEach
// public void tearDown() {
// data.add(matchupList);
// }
//
// @AfterAll
// public static void writeAfterTests() {
// try {
// writeToCsv("Data.csv", data);
// } catch (IOException e) {
//
// }
// }
//
//
// public static void writeToCsv(String filepath, List<List<List<List<Integer>>>> data) throws IOException {
// try (BufferedWriter writer = new BufferedWriter(new FileWriter(filepath))) {
//
// writer.write("TestID,Matchup,GameNr,Winner");
// for (int i = 0; i < data.size(); i++) {
// writer.write(",Iterations");
// }
//
// writer.newLine();
//
// for (int TestID = 0; TestID < data.size(); TestID++) {
// List<List<List<Integer>>> testCase = data.get(TestID);
//
// for (int matchupNr = 0; matchupNr < testCase.size(); matchupNr++) {
// List<List<Integer>> matchup = testCase.get(matchupNr);
//
// for (int gameNr = 0; gameNr < matchup.size(); gameNr++) {
// List<Integer> game = matchup.get(gameNr);
// writer.write((TestID + 1) + "," + (getMatchupName(matchupNr)) + "," + (gameNr + 1));
// for (int i = 0; i < game.size(); i++) {
// if (i == 0) {
// writer.write("," + getWinnerFromMatchup(game.get(i), matchupNr));
// } else {
// writer.write("," + game.get(i));
// }
// }
// writer.newLine();
// }
// }
// }
// }
//
// }
//
//
// private static final List<List<List<List<Integer>>>> data = new ArrayList<>();
// private List<List<List<Integer>>> matchupList = new ArrayList<>();
// private static final List<String> matchupNames = new ArrayList<>();
// private static final List<ArtificialPlayer[]> matchups = new ArrayList<>();
//
// private static String getMatchupName(int matchupNr) {
// return matchupNames.get(matchupNr);
// }
//
// private static ArtificialPlayer[] getMatchup(int matchupNr) {
// return matchups.get(matchupNr);
// }
//
// private static String getWinnerFromMatchup(Integer winner, int matchupNr) {
// String matchup = matchupNames.get(matchupNr);
//
// String[] parts = matchup.split(" vs ");
//
// if (parts.length != 2) {
// return "Invalid matchup formatting.";
// }
//
// return winner == 0 ? parts[0] : winner == 1 ? parts[1] : winner == -999 ? "NVT" : "Tie";
// }
//
// private static void addMatchup(ArtificialPlayer player1, ArtificialPlayer player2) {
// matchups.add(new ArtificialPlayer[]{player1, player2});
// matchupNames.add(player1.getName() + " vs " + player2.getName());
// }
//}
// private final Long[] fixedMoveSet = new Long[]{17592186044416L,
// 35184372088832L,
// 67108864L,
// 8796093022208L,
// 2251799813685248L,
// 288230376151711744L,
// 70368744177664L,
// 1125899906842624L,
// 137438953472L,
// 140737488355328L,
// 4503599627370496L,
// 2305843009213693952L,
// 18014398509481984L,
// 274877906944L,
// 576460752303423488L,
// -9223372036854775808L,
// 549755813888L,
// 1152921504606846976L,
// 144115188075855872L,
// 72057594037927936L,
// 36028797018963968L,
// 17179869184L,
// 2199023255552L,
// 1048576L,
// 4398046511104L,
// 281474976710656L,
// 9007199254740992L,
// 2147483648L,
// 1073741824L,
// 33554432L,
// 262144L,
// 8388608L,
// 8192L,
// 4611686018427387904L,
// 4294967296L,
// 524288L,
// 4096L,
// 16777216L,
// 65536L,
// 32L,
// 2048L,
// 8L,
// 4L,
// 8589934592L,
// 16L,
// 2097152L,
// 4194304L,
// 1024L,
// 512L,
// 16384L,
// 536870912L,
// 1099511627776L,
// 64L,
// 562949953421312L,
// 128L,
// 1L,
// 32768L,
// 2L,
// 256L,
// 131072L};
// }

View File

@@ -1,4 +0,0 @@
package research;
public record GameData(String AI1, String AI2, String winner, int turns, long millisecondsAI1, long millisecondsAI2) {}

View File

@@ -1,30 +0,0 @@
package research;
import org.toop.framework.game.players.ArtificialPlayer;
import java.util.ArrayList;
import java.util.List;
public class Matchup {
public ArtificialPlayer player1;
public ArtificialPlayer player2;
public Matchup(ArtificialPlayer player1, ArtificialPlayer player2) {
this.player1 = player1;
this.player2 = player2;
}
public Matchup() {}
public String toString() {
return player1.toString() + " VS " + player2.toString();
}
public ArtificialPlayer getPlayer1() {
return player1;
}
public ArtificialPlayer getPlayer2() {
return player2;
}
}