Machine learning for reversi.

performance improvements for reversi.getlegalmoves
This commit is contained in:
Ticho Hidding
2025-12-02 10:59:33 +01:00
parent 1a11827ba3
commit 7e913ff50f
11 changed files with 464 additions and 24 deletions

View File

@@ -2,6 +2,7 @@ package org.toop;
import org.toop.app.App; import org.toop.app.App;
import org.toop.framework.audio.*; import org.toop.framework.audio.*;
import org.toop.framework.machinelearning.NeuralNetwork;
import org.toop.framework.networking.NetworkingClientEventListener; import org.toop.framework.networking.NetworkingClientEventListener;
import org.toop.framework.networking.NetworkingClientManager; import org.toop.framework.networking.NetworkingClientManager;
import org.toop.framework.resource.ResourceLoader; import org.toop.framework.resource.ResourceLoader;
@@ -35,6 +36,10 @@ public final class Main {
).initListeners("medium-button-click.wav"); ).initListeners("medium-button-click.wav");
}).start(); }).start();
//NeuralNetwork nn = new NeuralNetwork();
//nn.init();
} }
} }

View File

@@ -15,6 +15,7 @@ import org.toop.game.reversi.ReversiAI;
import javafx.geometry.Pos; import javafx.geometry.Pos;
import javafx.scene.paint.Color; import javafx.scene.paint.Color;
import org.toop.game.reversi.ReversiAIML;
import java.awt.*; import java.awt.*;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
@@ -30,7 +31,7 @@ public final class ReversiGame {
private final BlockingQueue<Move> moveQueue; private final BlockingQueue<Move> moveQueue;
private final Reversi game; private final Reversi game;
private final ReversiAI ai; private final ReversiAIML ai;
private final GameView primary; private final GameView primary;
private final ReversiCanvas canvas; private final ReversiCanvas canvas;
@@ -46,7 +47,7 @@ public final class ReversiGame {
moveQueue = new LinkedBlockingQueue<>(); moveQueue = new LinkedBlockingQueue<>();
game = new Reversi(); game = new Reversi();
ai = new ReversiAI(); ai = new ReversiAIML();
isRunning = new AtomicBoolean(true); isRunning = new AtomicBoolean(true);
isPaused = new AtomicBoolean(false); isPaused = new AtomicBoolean(false);
@@ -324,7 +325,7 @@ public final class ReversiGame {
if (isLegalMove) { if (isLegalMove) {
moves = game.getFlipsForPotentialMove( moves = game.getFlipsForPotentialMove(
new Point(cellEntered%game.getColumnSize(),cellEntered/game.getRowSize()), new Point(cellEntered%game.getColumnSize(),cellEntered/game.getRowSize()),
game.getCurrentPlayer()); game.getCurrentPlayer(), game.getCurrentPlayer() == 'B'?'W':'B',game.makeBoardAGrid());
} }
canvas.drawHighlightDots(moves); canvas.drawHighlightDots(moves);
} }

View File

@@ -146,6 +146,18 @@
<artifactId>error_prone_annotations</artifactId> <artifactId>error_prone_annotations</artifactId>
<version>2.42.0</version> <version>2.42.0</version>
</dependency> </dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>1.0.0-M2.1</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.toop</groupId>
<artifactId>game</artifactId>
<version>0.1</version>
<scope>compile</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@@ -0,0 +1,182 @@
package org.toop.framework.machinelearning;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.toop.game.AI;
import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI;
import org.toop.game.reversi.ReversiAISimple;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static java.lang.Math.abs;
public class NeuralNetwork {
private MultiLayerConfiguration conf;
private MultiLayerNetwork model;
private ReversiAI reversiAI = new ReversiAI();
private AI<Reversi> opponentRand = new ReversiAI();
private AI<Reversi> opponentSimple = new ReversiAISimple();
public NeuralNetwork() {}
public void init(){
conf = new NeuralNetConfiguration.Builder()
.updater(new Adam(0.001))
.weightInit(WeightInit.XAVIER) //todo understand
.list()
.layer(new DenseLayer.Builder()
.nIn(64)
.nOut(128)
.activation(Activation.RELU)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nIn(128)
.nOut(64)
.activation(Activation.SOFTMAX)
.build())
.build();
model = new MultiLayerNetwork(conf);
IO.println(model.params());
loadModel();
IO.println(model.params());
model.init();
IO.println(model.summary());
model.setLearningRate(0.001);
trainingLoop();
saveModel();
}
public void saveModel(){
File modelFile = new File("reversi-model.zip");
try {
ModelSerializer.writeModel(model, modelFile, true);
}catch (Exception e){
e.printStackTrace();
}
}
public void loadModel(){
File modelFile = new File("reversi-model.zip");
try {
model = ModelSerializer.restoreMultiLayerNetwork(modelFile);
} catch (IOException e) {
e.printStackTrace();
}
}
public void trainingLoop(){
int totalGames = 5000;
double epsilon = 0.1;
long start = System.nanoTime();
for (int game = 0; game<totalGames; game++){
Reversi reversi = new Reversi();
List<StateAction> gameHistory = new ArrayList<>();
GameState state = GameState.NORMAL;
double reward = 0;
while (state != GameState.DRAW && state != GameState.WIN){
char curr = reversi.getCurrentPlayer();
Move move;
if (curr == 'B') {
int[] input = reversi.getBoardInt();
if (Math.random() < epsilon) {
Move[] moves = reversi.getLegalMoves();
move = moves[(int) (Math.random() * moves.length - .5f)];
} else {
INDArray boardInput = Nd4j.create(new int[][]{input});
INDArray prediction = model.output(boardInput);
int location = pickLegalMove(prediction, reversi);
gameHistory.add(new StateAction(input, location));
move = new Move(location, reversi.getCurrentPlayer());
}
}else{
move = reversiAI.findBestMove(reversi,5);
}
state = reversi.play(move);
}
//IO.println(model.params());
Reversi.Score score = reversi.getScore();
int scoreDif = abs(score.player1Score() - score.player2Score());
if (score.player1Score() > score.player2Score()){
reward = 1 + ((scoreDif / 64.0) * 0.5);
}else if (score.player1Score() < score.player2Score()){
reward = -1 - ((scoreDif / 64.0) * 0.5);
}else{
reward = 0;
}
for (StateAction step : gameHistory){
trainFromHistory(step, reward);
}
//IO.println("Wr: " + (double)p1wins/(game+1) + " draws: " + draws);
if(game % 100 == 0){
IO.println("Completed game " + game + " | Reward: " + reward);
//IO.println(Arrays.toString(reversi.getBoardDouble()));
}
}
long end = System.nanoTime();
IO.println((end-start));
}
private boolean isInCorner(Move move){
return move.position() == 0 || move.position() == 7 || move.position() == 56 || move.position() == 63;
}
private int pickLegalMove(INDArray prediction, Reversi reversi){
double[] probs = prediction.toDoubleVector();
Move[] legalMoves = reversi.getLegalMoves();
if (legalMoves.length == 0) return -1;
int bestMove = legalMoves[0].position();
double bestVal = probs[bestMove];
for (Move move : legalMoves){
if (probs[move.position()] > bestVal){
bestMove = move.position();
bestVal = probs[bestMove];
}
}
return bestMove;
}
private void trainFromHistory(StateAction step, double reward){
double[] output = new double[64];
output[step.action] = reward;
DataSet ds = new DataSet(
Nd4j.create(new int[][] { step.state }),
Nd4j.create(new double[][] { output })
);
model.fit(ds);
}
}

View File

@@ -0,0 +1,10 @@
package org.toop.framework.machinelearning;
public class StateAction {
int[] state;
int action;
public StateAction(int[] state, int action) {
this.state = state;
this.action = action;
}
}

View File

@@ -99,6 +99,16 @@
<artifactId>error_prone_annotations</artifactId> <artifactId>error_prone_annotations</artifactId>
<version>2.42.0</version> <version>2.42.0</version>
</dependency> </dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
</dependencies> </dependencies>

View File

@@ -15,6 +15,7 @@ public final class Reversi extends TurnBasedGame {
private int movesTaken; private int movesTaken;
private Set<Point> filledCells = new HashSet<>(); private Set<Point> filledCells = new HashSet<>();
private Move[] mostRecentlyFlippedPieces; private Move[] mostRecentlyFlippedPieces;
private char[][] cachedBoard;
public record Score(int player1Score, int player2Score) {} public record Score(int player1Score, int player2Score) {}
@@ -37,6 +38,7 @@ public final class Reversi extends TurnBasedGame {
this.setBoard(new Move(35, 'B')); this.setBoard(new Move(35, 'B'));
this.setBoard(new Move(36, 'W')); this.setBoard(new Move(36, 'W'));
updateFilledCellsSet(); updateFilledCellsSet();
cachedBoard = makeBoardAGrid();
} }
private void updateFilledCellsSet() { private void updateFilledCellsSet() {
for (int i = 0; i < 64; i++) { for (int i = 0; i < 64; i++) {
@@ -49,11 +51,13 @@ public final class Reversi extends TurnBasedGame {
@Override @Override
public Move[] getLegalMoves() { public Move[] getLegalMoves() {
final ArrayList<Move> legalMoves = new ArrayList<>(); final ArrayList<Move> legalMoves = new ArrayList<>();
char[][] boardGrid = makeBoardAGrid(); char[][] boardGrid = cachedBoard;
char currentPlayer = (this.getCurrentTurn()==0) ? 'B' : 'W'; char currentPlayer = (this.getCurrentTurn()==0) ? 'B' : 'W';
Set<Point> adjCell = getAdjacentCells(boardGrid); char opponent = (currentPlayer=='W') ? 'B' : 'W';
Set<Point> adjCell = getAdjacentCells(boardGrid, opponent);
for (Point point : adjCell){ for (Point point : adjCell){
Move[] moves = getFlipsForPotentialMove(point,currentPlayer); Move[] moves = getFlipsForPotentialMove(point, currentPlayer, opponent, boardGrid);
int score = moves.length; int score = moves.length;
if (score > 0){ if (score > 0){
legalMoves.add(new Move(point.x + point.y * this.getRowSize(), currentPlayer)); legalMoves.add(new Move(point.x + point.y * this.getRowSize(), currentPlayer));
@@ -62,9 +66,10 @@ public final class Reversi extends TurnBasedGame {
return legalMoves.toArray(new Move[0]); return legalMoves.toArray(new Move[0]);
} }
private Set<Point> getAdjacentCells(char[][] boardGrid) { private Set<Point> getAdjacentCells(char[][] boardGrid, char opponent) {
Set<Point> possibleCells = new HashSet<>(); Set<Point> possibleCells = new HashSet<>();
for (Point point : filledCells) { //for every filled cell for (Point point : filledCells) { //for every filled cell
if (boardGrid[point.x][point.y] == opponent) {
for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //check adjacent cells for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //check adjacent cells
for (int deltaRow = -1; deltaRow <= 1; deltaRow++) { //orthogonally and diagonally for (int deltaRow = -1; deltaRow <= 1; deltaRow++) { //orthogonally and diagonally
int newX = point.x + deltaColumn, newY = point.y + deltaRow; int newX = point.x + deltaColumn, newY = point.y + deltaRow;
@@ -78,17 +83,18 @@ public final class Reversi extends TurnBasedGame {
} }
} }
} }
}
return possibleCells; return possibleCells;
} }
public Move[] getFlipsForPotentialMove(Point point, char currentPlayer) { public Move[] getFlipsForPotentialMove(Point point, char currentPlayer, char opponent, char[][] boardGrid) {
final ArrayList<Move> movesToFlip = new ArrayList<>(); final ArrayList<Move> movesToFlip = new ArrayList<>();
for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //for all directions for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //for all directions
for (int deltaRow = -1; deltaRow <= 1; deltaRow++) { for (int deltaRow = -1; deltaRow <= 1; deltaRow++) {
if (deltaColumn == 0 && deltaRow == 0){ if (deltaColumn == 0 && deltaRow == 0){
continue; continue;
} }
Move[] moves = getFlipsInDirection(point,makeBoardAGrid(),currentPlayer,deltaColumn,deltaRow); Move[] moves = getFlipsInDirection(point, boardGrid, currentPlayer, opponent, deltaColumn, deltaRow);
if (moves != null) { //getFlipsInDirection if (moves != null) { //getFlipsInDirection
movesToFlip.addAll(Arrays.asList(moves)); movesToFlip.addAll(Arrays.asList(moves));
} }
@@ -97,8 +103,14 @@ public final class Reversi extends TurnBasedGame {
return movesToFlip.toArray(new Move[0]); return movesToFlip.toArray(new Move[0]);
} }
private Move[] getFlipsInDirection(Point point, char[][] boardGrid, char currentPlayer, int dirX, int dirY) { public Move[] getFlipsForPotentialMove(Move move) {
char opponent = getOpponent(currentPlayer); char curr = getCurrentPlayer();
char opp = getOpponent(curr);
Point point = new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize());
return getFlipsForPotentialMove(point, curr, opp, cachedBoard);
}
private Move[] getFlipsInDirection(Point point, char[][] boardGrid, char currentPlayer, char opponent, int dirX, int dirY) {
final ArrayList<Move> movesToFlip = new ArrayList<>(); final ArrayList<Move> movesToFlip = new ArrayList<>();
int x = point.x + dirX; int x = point.x + dirX;
int y = point.y + dirY; int y = point.y + dirY;
@@ -123,7 +135,7 @@ public final class Reversi extends TurnBasedGame {
return x >= 0 && x < this.getColumnSize() && y >= 0 && y < this.getRowSize(); return x >= 0 && x < this.getColumnSize() && y >= 0 && y < this.getRowSize();
} }
private char[][] makeBoardAGrid() { public char[][] makeBoardAGrid() {
char[][] boardGrid = new char[this.getRowSize()][this.getColumnSize()]; char[][] boardGrid = new char[this.getRowSize()][this.getColumnSize()];
for (int i = 0; i < 64; i++) { for (int i = 0; i < 64; i++) {
boardGrid[i / this.getRowSize()][i % this.getColumnSize()] = this.getBoard()[i]; //boardGrid[y -> row] [x -> column] boardGrid[i / this.getRowSize()][i % this.getColumnSize()] = this.getBoard()[i]; //boardGrid[y -> row] [x -> column]
@@ -133,6 +145,9 @@ public final class Reversi extends TurnBasedGame {
@Override @Override
public GameState play(Move move) { public GameState play(Move move) {
if (cachedBoard == null) {
cachedBoard = makeBoardAGrid();
}
Move[] legalMoves = getLegalMoves(); Move[] legalMoves = getLegalMoves();
boolean moveIsLegal = false; boolean moveIsLegal = false;
for (Move legalMove : legalMoves) { //check if the move is legal for (Move legalMove : legalMoves) { //check if the move is legal
@@ -145,13 +160,14 @@ public final class Reversi extends TurnBasedGame {
return null; return null;
} }
Move[] moves = sortMovesFromCenter(getFlipsForPotentialMove(new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()), move.value()),move); Move[] moves = sortMovesFromCenter(getFlipsForPotentialMove(new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()), move.value(),move.value() == 'B'? 'W': 'B',makeBoardAGrid()),move);
mostRecentlyFlippedPieces = moves; mostRecentlyFlippedPieces = moves;
this.setBoard(move); //place the move on the board this.setBoard(move); //place the move on the board
for (Move m : moves) { for (Move m : moves) {
this.setBoard(m); //flip the correct pieces on the board this.setBoard(m); //flip the correct pieces on the board
} }
filledCells.add(new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize())); filledCells.add(new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize()));
cachedBoard = makeBoardAGrid();
nextTurn(); nextTurn();
if (getLegalMoves().length == 0) { //skip the players turn when there are no legal moves if (getLegalMoves().length == 0) { //skip the players turn when there are no legal moves
skipMyTurn(); skipMyTurn();
@@ -172,7 +188,7 @@ public final class Reversi extends TurnBasedGame {
} }
private void skipMyTurn(){ private void skipMyTurn(){
IO.println("TURN " + getCurrentPlayer() + " SKIPPED"); //IO.println("TURN " + getCurrentPlayer() + " SKIPPED");
//TODO: notify user that a turn has been skipped //TODO: notify user that a turn has been skipped
nextTurn(); nextTurn();
} }
@@ -207,6 +223,32 @@ public final class Reversi extends TurnBasedGame {
} }
return new Score(player1Score, player2Score); return new Score(player1Score, player2Score);
} }
public boolean isGameOver(){
Move[] legalMovesW = getLegalMoves();
nextTurn();
Move[] legalMovesB = getLegalMoves();
nextTurn();
if (legalMovesW.length + legalMovesB.length == 0) {
return true;
}
return false;
}
public int getWinner(){
if (!isGameOver()) {
return 0;
}
Score score = getScore();
if (score.player1Score() > score.player2Score()) {
return 1;
}
else if (score.player1Score() < score.player2Score()) {
return 2;
}
return 0;
}
private Move[] sortMovesFromCenter(Move[] moves, Move center) { //sorts the pieces to be flipped for animation purposes private Move[] sortMovesFromCenter(Move[] moves, Move center) { //sorts the pieces to be flipped for animation purposes
int centerX = center.position()%this.getColumnSize(); int centerX = center.position()%this.getColumnSize();
int centerY = center.position()/this.getRowSize(); int centerY = center.position()/this.getRowSize();
@@ -226,4 +268,34 @@ public final class Reversi extends TurnBasedGame {
public Move[] getMostRecentlyFlippedPieces() { public Move[] getMostRecentlyFlippedPieces() {
return mostRecentlyFlippedPieces; return mostRecentlyFlippedPieces;
} }
public int[] getBoardInt(){
char[] input = getBoard();
int[] result = new int[input.length];
for (int i = 0; i < input.length; i++) {
switch (input[i]) {
case 'W':
result[i] = -1;
break;
case 'B':
result[i] = 1;
break;
case ' ':
default:
result[i] = 0;
break;
}
}
return result;
}
public Point moveToPoint(Move move){
return new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize());
}
public void printBoard(){
for (int row = 0; row < this.getRowSize(); row++) {
IO.println(Arrays.toString(cachedBoard[row]));
}
}
} }

View File

@@ -4,6 +4,7 @@ import org.toop.game.AI;
import org.toop.game.records.Move; import org.toop.game.records.Move;
public final class ReversiAI extends AI<Reversi> { public final class ReversiAI extends AI<Reversi> {
@Override @Override
public Move findBestMove(Reversi game, int depth) { public Move findBestMove(Reversi game, int depth) {
Move[] moves = game.getLegalMoves(); Move[] moves = game.getLegalMoves();

View File

@@ -0,0 +1,52 @@
package org.toop.game.reversi;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.toop.game.AI;
import org.toop.game.records.Move;
import java.io.IOException;
import java.io.InputStream;
public class ReversiAIML extends AI<Reversi>{
MultiLayerNetwork model;
public ReversiAIML() {
InputStream is = getClass().getResourceAsStream("/reversi-model.zip");
try {
assert is != null;
model = ModelSerializer.restoreMultiLayerNetwork(is);
} catch (IOException e) {}
}
public Move findBestMove(Reversi reversi, int depth){
int[] input = reversi.getBoardInt();
INDArray boardInput = Nd4j.create(new int[][] { input });
INDArray prediction = model.output(boardInput);
int move = pickLegalMove(prediction, reversi);
return new Move(move, reversi.getCurrentPlayer());
}
private int pickLegalMove(INDArray prediction, Reversi reversi){
double[] probs = prediction.toDoubleVector();
Move[] legalMoves = reversi.getLegalMoves();
if (legalMoves.length == 0) return -1;
int bestMove = legalMoves[0].position();
double bestVal = probs[bestMove];
for (Move move : legalMoves){
if (probs[move.position()] > bestVal){
bestMove = move.position();
bestVal = probs[bestMove];
}
}
return bestMove;
}
}

View File

@@ -0,0 +1,57 @@
package org.toop.game.reversi;
import org.toop.game.AI;
import org.toop.game.records.Move;
import java.util.Arrays;
public class ReversiAISimple extends AI<Reversi> {
@Override
public Move findBestMove(Reversi game, int depth) {
//IO.println("****START FIND BEST MOVE****");
Move[] moves = game.getLegalMoves();
//game.printBoard();
//IO.println("Legal moves: " + Arrays.toString(moves));
Move bestMove;
Move bestMoveScore = moves[0];
Move bestMoveOptions = moves[0];
int bestScore = -1;
int bestOptions = -1;
for (Move move : moves){
int numOpt = getNumberOfOptions(game, move);
if (numOpt > bestOptions) {
bestOptions = numOpt;
bestMoveOptions = move;
}
int numSco = getScore(game, move);
if (numSco > bestScore) {
bestScore = numSco;
bestMoveScore = move;
}
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
}
if (bestScore > bestOptions) {
bestMove = bestMoveScore;
}
else{
bestMove = bestMoveOptions;
}
return bestMove;
}
private int getNumberOfOptions(Reversi game, Move move){
Reversi copy = new Reversi(game);
copy.play(move);
return copy.getLegalMoves().length;
}
private int getScore(Reversi game, Move move){
return game.getFlipsForPotentialMove(move).length;
}
}

View File

@@ -8,17 +8,24 @@ import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move; import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi; import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI; import org.toop.game.reversi.ReversiAI;
import org.toop.game.reversi.ReversiAIML;
import org.toop.game.reversi.ReversiAISimple;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
class ReversiTest { class ReversiTest {
private Reversi game; private Reversi game;
private ReversiAI ai; private ReversiAI ai;
private ReversiAIML aiml;
private ReversiAISimple aiSimple;
@BeforeEach @BeforeEach
void setup() { void setup() {
game = new Reversi(); game = new Reversi();
ai = new ReversiAI(); ai = new ReversiAI();
aiml = new ReversiAIML();
aiSimple = new ReversiAISimple();
} }
@@ -190,4 +197,35 @@ class ReversiTest {
} }
return false; return false;
} }
@Test
void testAIvsAIML(){
IO.println("Testing AI simple ...");
int totalGames = 5000;
int p1wins = 0;
int p2wins = 0;
int draws = 0;
for (int i = 0; i < totalGames; i++) {
game = new Reversi();
while (!game.isGameOver()) {
char curr = game.getCurrentPlayer();
if (curr == 'W') {
game.play(ai.findBestMove(game,5));
}
else {
game.play(ai.findBestMove(game,5));
}
}
int winner = game.getWinner();
if (winner == 1) {
p1wins++;
}else if (winner == 2) {
p2wins++;
}
else{
draws++;
}
}
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws);
}
} }