diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index 02208f7..2a36d69 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -38,8 +38,8 @@ public final class Main { }).start(); - //NeuralNetwork nn = new NeuralNetwork(); - //nn.init(); + NeuralNetwork nn = new NeuralNetwork(); + nn.init(); } } diff --git a/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java b/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java index 3603694..9a7af90 100644 --- a/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java +++ b/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java @@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState; import org.toop.game.records.Move; import org.toop.game.reversi.Reversi; import org.toop.game.reversi.ReversiAI; +import org.toop.game.reversi.ReversiAIML; import org.toop.game.reversi.ReversiAISimple; import java.io.File; @@ -26,14 +27,17 @@ import java.util.ArrayList; import java.util.List; import static java.lang.Math.abs; +import static java.lang.Math.random; public class NeuralNetwork { private MultiLayerConfiguration conf; private MultiLayerNetwork model; + private AI opponentAI; private ReversiAI reversiAI = new ReversiAI(); private AI opponentRand = new ReversiAI(); private AI opponentSimple = new ReversiAISimple(); + private AI opponentAIML = new ReversiAIML(); public NeuralNetwork() {} @@ -61,7 +65,7 @@ public class NeuralNetwork { model.init(); IO.println(model.summary()); - model.setLearningRate(0.001); + model.setLearningRate(0.0003); trainingLoop(); saveModel(); } @@ -85,13 +89,15 @@ public class NeuralNetwork { } public void trainingLoop(){ - int totalGames = 5000; - double epsilon = 0.1; + int totalGames = 50000; + double epsilon = 0.05; long start = System.nanoTime(); for (int game = 0; game gameHistory = new ArrayList<>(); GameState state = GameState.NORMAL; @@ -100,7 +106,7 @@ public class NeuralNetwork { while (state != GameState.DRAW && state != GameState.WIN){ char curr = reversi.getCurrentPlayer(); Move move; - if (curr == 'B') { + if (curr == modelPlayer) { int[] input = reversi.getBoardInt(); if (Math.random() < epsilon) { Move[] moves = reversi.getLegalMoves(); @@ -114,7 +120,7 @@ public class NeuralNetwork { move = new Move(location, reversi.getCurrentPlayer()); } }else{ - move = reversiAI.findBestMove(reversi,5); + move = opponentAI.findBestMove(reversi,5); } state = reversi.play(move); } @@ -130,6 +136,10 @@ public class NeuralNetwork { reward = 0; } + if (modelPlayer == 'W'){ + reward = -reward; + } + for (StateAction step : gameHistory){ trainFromHistory(step, reward); @@ -167,6 +177,15 @@ public class NeuralNetwork { return bestMove; } + private AI getOpponentAI(){ + return switch ((int) (Math.random() * 4)) { + case 0 -> opponentRand; + case 1 -> opponentSimple; + case 2 -> opponentAIML; + default -> opponentRand; + }; + } + private void trainFromHistory(StateAction step, double reward){ double[] output = new double[64]; output[step.action] = reward; diff --git a/game/src/main/java/org/toop/game/reversi/ReversiAIML.java b/game/src/main/java/org/toop/game/reversi/ReversiAIML.java index 78f6915..6874f03 100644 --- a/game/src/main/java/org/toop/game/reversi/ReversiAIML.java +++ b/game/src/main/java/org/toop/game/reversi/ReversiAIML.java @@ -10,6 +10,8 @@ import org.toop.game.records.Move; import java.io.IOException; import java.io.InputStream; +import static java.lang.Math.random; + public class ReversiAIML extends AI{ MultiLayerNetwork model; @@ -28,23 +30,27 @@ public class ReversiAIML extends AI{ INDArray boardInput = Nd4j.create(new int[][] { input }); INDArray prediction = model.output(boardInput); - int move = pickLegalMove(prediction, reversi); + int move = pickLegalMove(prediction,reversi); return new Move(move, reversi.getCurrentPlayer()); } - private int pickLegalMove(INDArray prediction, Reversi reversi){ - double[] probs = prediction.toDoubleVector(); + private int pickLegalMove(INDArray prediction, Reversi reversi) { + double[] logits = prediction.toDoubleVector(); Move[] legalMoves = reversi.getLegalMoves(); if (legalMoves.length == 0) return -1; int bestMove = legalMoves[0].position(); - double bestVal = probs[bestMove]; + double bestVal = logits[bestMove]; - for (Move move : legalMoves){ - if (probs[move.position()] > bestVal){ - bestMove = move.position(); - bestVal = probs[bestMove]; + if (random() < 0.01){ + return legalMoves[(int)(random()*legalMoves.length-.5)].position(); + } + for (Move move : legalMoves) { + int pos = move.position(); + if (logits[pos] > bestVal) { + bestMove = pos; + bestVal = logits[pos]; } } return bestMove; diff --git a/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java b/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java index 109dcc9..d5b9fe2 100644 --- a/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java +++ b/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java @@ -33,6 +33,12 @@ public class ReversiAISimple extends AI { bestScore = numSco; bestMoveScore = move; } + if (numSco == bestScore || numOpt == bestOptions) { + if (Math.random() < 0.5) { + bestMoveOptions = move; + bestMoveScore = move; + } + } //IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco); } diff --git a/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java b/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java index 4f2c80f..25daa26 100644 --- a/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java +++ b/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java @@ -4,6 +4,7 @@ import java.util.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.toop.game.AI; import org.toop.game.enumerators.GameState; import org.toop.game.records.Move; import org.toop.game.reversi.Reversi; @@ -18,6 +19,8 @@ class ReversiTest { private ReversiAI ai; private ReversiAIML aiml; private ReversiAISimple aiSimple; + private AI player1; + private AI player2; @BeforeEach void setup() { @@ -32,20 +35,20 @@ class ReversiTest { @Test void testCorrectStartPiecesPlaced() { assertNotNull(game); - assertEquals('W',game.getBoard()[27]); - assertEquals('B',game.getBoard()[28]); - assertEquals('B',game.getBoard()[35]); - assertEquals('W',game.getBoard()[36]); + assertEquals('W', game.getBoard()[27]); + assertEquals('B', game.getBoard()[28]); + assertEquals('B', game.getBoard()[35]); + assertEquals('W', game.getBoard()[36]); } @Test void testGetLegalMovesAtStart() { Move[] moves = game.getLegalMoves(); List expectedMoves = List.of( - new Move(19,'B'), - new Move(26,'B'), - new Move(37,'B'), - new Move(44,'B') + new Move(19, 'B'), + new Move(26, 'B'), + new Move(37, 'B'), + new Move(44, 'B') ); assertNotNull(moves); assertTrue(moves.length > 0); @@ -83,12 +86,12 @@ class ReversiTest { @Test void testCountScoreCorrectlyAtStart() { - long start = System.nanoTime(); + long start = System.nanoTime(); Reversi.Score score = game.getScore(); assertEquals(2, score.player1Score()); // Black assertEquals(2, score.player2Score()); // White - long end = System.nanoTime(); - IO.println((end-start)); + long end = System.nanoTime(); + IO.println((end - start)); } @Test @@ -97,7 +100,7 @@ class ReversiTest { game.play(new Move(20, 'W')); Move[] moves = game.getLegalMoves(); List expectedMoves = List.of( - new Move(13,'B'), + new Move(13, 'B'), new Move(21, 'B'), new Move(29, 'B'), new Move(37, 'B'), @@ -110,11 +113,11 @@ class ReversiTest { @Test void testCountScoreCorrectlyAtEnd() { - for (int i = 0; i < 1; i++){ - game = new Reversi(); + for (int i = 0; i < 1; i++) { + game = new Reversi(); Move[] legalMoves = game.getLegalMoves(); - while(legalMoves.length > 0) { - game.play(legalMoves[(int)(Math.random()*legalMoves.length)]); + while (legalMoves.length > 0) { + game.play(legalMoves[(int) (Math.random() * legalMoves.length)]); legalMoves = game.getLegalMoves(); } Reversi.Score score = game.getScore(); @@ -152,8 +155,8 @@ class ReversiTest { game.play(new Move(60, 'W')); game.play(new Move(59, 'B')); assertEquals('B', game.getCurrentPlayer()); - game.play(ai.findBestMove(game,5)); - game.play(ai.findBestMove(game,5)); + game.play(ai.findBestMove(game, 5)); + game.play(ai.findBestMove(game, 5)); } @Test @@ -186,9 +189,9 @@ class ReversiTest { @Test void testAISelectsLegalMove() { - Move move = ai.findBestMove(game,4); + Move move = ai.findBestMove(game, 4); assertNotNull(move); - assertTrue(containsMove(game.getLegalMoves(),move), "AI should always choose a legal move"); + assertTrue(containsMove(game.getLegalMoves(), move), "AI should always choose a legal move"); } private boolean containsMove(Move[] moves, Move move) { @@ -199,33 +202,60 @@ class ReversiTest { } @Test - void testAIvsAIML(){ - IO.println("Testing AI simple ..."); - int totalGames = 5000; + void testAis() { + player1 = aiml; + player2 = ai; + testAIvsAIML(); + player2 = aiSimple; + testAIvsAIML(); + player1 = ai; + testAIvsAIML(); + player2 = aiml; + testAIvsAIML(); + player1 = aiml; + testAIvsAIML(); + player1 = aiSimple; + testAIvsAIML(); + } + + @Test + void testAIvsAIML() { + if(player1 == null || player2 == null) { + player1 = aiml; + player2 = ai; + } + int totalGames = 2000; + IO.println("Testing... " + player1.getClass().getSimpleName() + " vs " + player2.getClass().getSimpleName() + " for " + totalGames + " games"); int p1wins = 0; int p2wins = 0; int draws = 0; + List moves = new ArrayList<>(); for (int i = 0; i < totalGames; i++) { - game = new Reversi(); + game = new Reversi(); while (!game.isGameOver()) { char curr = game.getCurrentPlayer(); - if (curr == 'W') { - game.play(ai.findBestMove(game,5)); - } - else { - game.play(ai.findBestMove(game,5)); + Move move; + if (curr == 'B') { + move = player1.findBestMove(game, 5); + } else { + move = player2.findBestMove(game, 5); } + if (i%500 == 0) moves.add(move.position()); + game.play(move); + } + if (i%500 == 0) { + IO.println(moves); + moves.clear(); } int winner = game.getWinner(); if (winner == 1) { p1wins++; - }else if (winner == 2) { + } else if (winner == 2) { p2wins++; - } - else{ + } else { draws++; } } - IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws); + IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double) p1wins / totalGames + "\np2wins: " + p2wins + " draws: " + draws); } -} +} \ No newline at end of file