diff --git a/app/src/main/java/org/toop/Main.java b/app/src/main/java/org/toop/Main.java index 657f9bf..02208f7 100644 --- a/app/src/main/java/org/toop/Main.java +++ b/app/src/main/java/org/toop/Main.java @@ -2,6 +2,7 @@ package org.toop; import org.toop.app.App; import org.toop.framework.audio.*; +import org.toop.framework.machinelearning.NeuralNetwork; import org.toop.framework.networking.NetworkingClientEventListener; import org.toop.framework.networking.NetworkingClientManager; import org.toop.framework.resource.ResourceLoader; @@ -35,6 +36,10 @@ public final class Main { ).initListeners("medium-button-click.wav"); }).start(); + + + //NeuralNetwork nn = new NeuralNetwork(); + //nn.init(); } } diff --git a/app/src/main/java/org/toop/app/game/ReversiGame.java b/app/src/main/java/org/toop/app/game/ReversiGame.java index 5e472f5..f274f5a 100644 --- a/app/src/main/java/org/toop/app/game/ReversiGame.java +++ b/app/src/main/java/org/toop/app/game/ReversiGame.java @@ -15,6 +15,7 @@ import org.toop.game.reversi.ReversiAI; import javafx.geometry.Pos; import javafx.scene.paint.Color; +import org.toop.game.reversi.ReversiAIML; import java.awt.*; import java.util.concurrent.BlockingQueue; @@ -30,7 +31,7 @@ public final class ReversiGame { private final BlockingQueue moveQueue; private final Reversi game; - private final ReversiAI ai; + private final ReversiAIML ai; private final GameView primary; private final ReversiCanvas canvas; @@ -46,7 +47,7 @@ public final class ReversiGame { moveQueue = new LinkedBlockingQueue<>(); game = new Reversi(); - ai = new ReversiAI(); + ai = new ReversiAIML(); isRunning = new AtomicBoolean(true); isPaused = new AtomicBoolean(false); @@ -324,7 +325,7 @@ public final class ReversiGame { if (isLegalMove) { moves = game.getFlipsForPotentialMove( new Point(cellEntered%game.getColumnSize(),cellEntered/game.getRowSize()), - game.getCurrentPlayer()); + game.getCurrentPlayer(), game.getCurrentPlayer() == 'B'?'W':'B',game.makeBoardAGrid()); } 
canvas.drawHighlightDots(moves); } diff --git a/framework/pom.xml b/framework/pom.xml index b5796c5..d215eb6 100644 --- a/framework/pom.xml +++ b/framework/pom.xml @@ -146,7 +146,19 @@ error_prone_annotations 2.42.0 - + + org.deeplearning4j + deeplearning4j-nn + 1.0.0-M2.1 + compile + + + org.toop + game + 0.1 + compile + + diff --git a/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java b/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java new file mode 100644 index 0000000..3603694 --- /dev/null +++ b/framework/src/main/java/org/toop/framework/machinelearning/NeuralNetwork.java @@ -0,0 +1,182 @@ +package org.toop.framework.machinelearning; + +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.toop.game.AI; +import org.toop.game.enumerators.GameState; +import org.toop.game.records.Move; +import org.toop.game.reversi.Reversi; +import org.toop.game.reversi.ReversiAI; +import org.toop.game.reversi.ReversiAISimple; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.lang.Math.abs; + +public class NeuralNetwork { + + private MultiLayerConfiguration conf; + private MultiLayerNetwork model; + private ReversiAI reversiAI = new ReversiAI(); + private AI opponentRand = new ReversiAI(); + private AI opponentSimple = new ReversiAISimple(); + + 
+ public NeuralNetwork() {} + + public void init(){ + conf = new NeuralNetConfiguration.Builder() + .updater(new Adam(0.001)) + .weightInit(WeightInit.XAVIER) //todo understand + .list() + .layer(new DenseLayer.Builder() + .nIn(64) + .nOut(128) + .activation(Activation.RELU) + .build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .nIn(128) + .nOut(64) + .activation(Activation.SOFTMAX) + .build()) + .build(); + model = new MultiLayerNetwork(conf); + IO.println(model.params()); + loadModel(); + IO.println(model.params()); + model.init(); + IO.println(model.summary()); + + model.setLearningRate(0.001); + trainingLoop(); + saveModel(); + } + + public void saveModel(){ + File modelFile = new File("reversi-model.zip"); + try { + ModelSerializer.writeModel(model, modelFile, true); + }catch (Exception e){ + e.printStackTrace(); + } + } + + public void loadModel(){ + File modelFile = new File("reversi-model.zip"); + try { + model = ModelSerializer.restoreMultiLayerNetwork(modelFile); + } catch (IOException e) { + e.printStackTrace(); + } + } + + public void trainingLoop(){ + int totalGames = 5000; + double epsilon = 0.1; + + long start = System.nanoTime(); + + for (int game = 0; game gameHistory = new ArrayList<>(); + GameState state = GameState.NORMAL; + + double reward = 0; + + while (state != GameState.DRAW && state != GameState.WIN){ + char curr = reversi.getCurrentPlayer(); + Move move; + if (curr == 'B') { + int[] input = reversi.getBoardInt(); + if (Math.random() < epsilon) { + Move[] moves = reversi.getLegalMoves(); + move = moves[(int) (Math.random() * moves.length - .5f)]; + } else { + INDArray boardInput = Nd4j.create(new int[][]{input}); + INDArray prediction = model.output(boardInput); + + int location = pickLegalMove(prediction, reversi); + gameHistory.add(new StateAction(input, location)); + move = new Move(location, reversi.getCurrentPlayer()); + } + }else{ + move = reversiAI.findBestMove(reversi,5); + } + state = 
reversi.play(move); + } + + //IO.println(model.params()); + Reversi.Score score = reversi.getScore(); + int scoreDif = abs(score.player1Score() - score.player2Score()); + if (score.player1Score() > score.player2Score()){ + reward = 1 + ((scoreDif / 64.0) * 0.5); + }else if (score.player1Score() < score.player2Score()){ + reward = -1 - ((scoreDif / 64.0) * 0.5); + }else{ + reward = 0; + } + + + for (StateAction step : gameHistory){ + trainFromHistory(step, reward); + } + + //IO.println("Wr: " + (double)p1wins/(game+1) + " draws: " + draws); + if(game % 100 == 0){ + IO.println("Completed game " + game + " | Reward: " + reward); + //IO.println(Arrays.toString(reversi.getBoardDouble())); + } + } + long end = System.nanoTime(); + IO.println((end-start)); + } + + private boolean isInCorner(Move move){ + return move.position() == 0 || move.position() == 7 || move.position() == 56 || move.position() == 63; + } + + private int pickLegalMove(INDArray prediction, Reversi reversi){ + double[] probs = prediction.toDoubleVector(); + Move[] legalMoves = reversi.getLegalMoves(); + + if (legalMoves.length == 0) return -1; + + int bestMove = legalMoves[0].position(); + double bestVal = probs[bestMove]; + + for (Move move : legalMoves){ + if (probs[move.position()] > bestVal){ + bestMove = move.position(); + bestVal = probs[bestMove]; + } + } + return bestMove; + } + + private void trainFromHistory(StateAction step, double reward){ + double[] output = new double[64]; + output[step.action] = reward; + + DataSet ds = new DataSet( + Nd4j.create(new int[][] { step.state }), + Nd4j.create(new double[][] { output }) + ); + + model.fit(ds); + + } +} diff --git a/framework/src/main/java/org/toop/framework/machinelearning/StateAction.java b/framework/src/main/java/org/toop/framework/machinelearning/StateAction.java new file mode 100644 index 0000000..58a7ffb --- /dev/null +++ b/framework/src/main/java/org/toop/framework/machinelearning/StateAction.java @@ -0,0 +1,10 @@ +package 
org.toop.framework.machinelearning; + +public class StateAction { + int[] state; + int action; + public StateAction(int[] state, int action) { + this.state = state; + this.action = action; + } +} diff --git a/game/pom.xml b/game/pom.xml index 6785e1c..e4c7f8a 100644 --- a/game/pom.xml +++ b/game/pom.xml @@ -99,6 +99,16 @@ error_prone_annotations 2.42.0 + + org.deeplearning4j + deeplearning4j-core + 1.0.0-M2.1 + + + org.nd4j + nd4j-native-platform + 1.0.0-M2.1 + diff --git a/game/src/main/java/org/toop/game/reversi/Reversi.java b/game/src/main/java/org/toop/game/reversi/Reversi.java index a13a44e..42e0485 100644 --- a/game/src/main/java/org/toop/game/reversi/Reversi.java +++ b/game/src/main/java/org/toop/game/reversi/Reversi.java @@ -15,6 +15,7 @@ public final class Reversi extends TurnBasedGame { private int movesTaken; private Set filledCells = new HashSet<>(); private Move[] mostRecentlyFlippedPieces; + private char[][] cachedBoard; public record Score(int player1Score, int player2Score) {} @@ -37,6 +38,7 @@ public final class Reversi extends TurnBasedGame { this.setBoard(new Move(35, 'B')); this.setBoard(new Move(36, 'W')); updateFilledCellsSet(); + cachedBoard = makeBoardAGrid(); } private void updateFilledCellsSet() { for (int i = 0; i < 64; i++) { @@ -49,11 +51,13 @@ public final class Reversi extends TurnBasedGame { @Override public Move[] getLegalMoves() { final ArrayList legalMoves = new ArrayList<>(); - char[][] boardGrid = makeBoardAGrid(); + char[][] boardGrid = cachedBoard; char currentPlayer = (this.getCurrentTurn()==0) ? 'B' : 'W'; - Set adjCell = getAdjacentCells(boardGrid); + char opponent = (currentPlayer=='W') ? 
'B' : 'W'; + + Set adjCell = getAdjacentCells(boardGrid, opponent); for (Point point : adjCell){ - Move[] moves = getFlipsForPotentialMove(point,currentPlayer); + Move[] moves = getFlipsForPotentialMove(point, currentPlayer, opponent, boardGrid); int score = moves.length; if (score > 0){ legalMoves.add(new Move(point.x + point.y * this.getRowSize(), currentPlayer)); @@ -62,18 +66,20 @@ public final class Reversi extends TurnBasedGame { return legalMoves.toArray(new Move[0]); } - private Set getAdjacentCells(char[][] boardGrid) { + private Set getAdjacentCells(char[][] boardGrid, char opponent) { Set possibleCells = new HashSet<>(); for (Point point : filledCells) { //for every filled cell - for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++){ //check adjacent cells - for (int deltaRow = -1; deltaRow <= 1; deltaRow++){ //orthogonally and diagonally - int newX = point.x + deltaColumn, newY = point.y + deltaRow; - if (deltaColumn == 0 && deltaRow == 0 //continue if out of bounds - || !isOnBoard(newX, newY)) { - continue; - } - if (boardGrid[newY][newX] == EMPTY) { //check if the cell is empty - possibleCells.add(new Point(newX, newY)); //and then add it to the set of possible moves + if (boardGrid[point.x][point.y] == opponent) { + for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //check adjacent cells + for (int deltaRow = -1; deltaRow <= 1; deltaRow++) { //orthogonally and diagonally + int newX = point.x + deltaColumn, newY = point.y + deltaRow; + if (deltaColumn == 0 && deltaRow == 0 //continue if out of bounds + || !isOnBoard(newX, newY)) { + continue; + } + if (boardGrid[newY][newX] == EMPTY) { //check if the cell is empty + possibleCells.add(new Point(newX, newY)); //and then add it to the set of possible moves + } } } } @@ -81,14 +87,14 @@ public final class Reversi extends TurnBasedGame { return possibleCells; } - public Move[] getFlipsForPotentialMove(Point point, char currentPlayer) { + public Move[] getFlipsForPotentialMove(Point point, 
char currentPlayer, char opponent, char[][] boardGrid) { final ArrayList movesToFlip = new ArrayList<>(); for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //for all directions for (int deltaRow = -1; deltaRow <= 1; deltaRow++) { if (deltaColumn == 0 && deltaRow == 0){ continue; } - Move[] moves = getFlipsInDirection(point,makeBoardAGrid(),currentPlayer,deltaColumn,deltaRow); + Move[] moves = getFlipsInDirection(point, boardGrid, currentPlayer, opponent, deltaColumn, deltaRow); if (moves != null) { //getFlipsInDirection movesToFlip.addAll(Arrays.asList(moves)); } @@ -97,8 +103,14 @@ public final class Reversi extends TurnBasedGame { return movesToFlip.toArray(new Move[0]); } - private Move[] getFlipsInDirection(Point point, char[][] boardGrid, char currentPlayer, int dirX, int dirY) { - char opponent = getOpponent(currentPlayer); + public Move[] getFlipsForPotentialMove(Move move) { + char curr = getCurrentPlayer(); + char opp = getOpponent(curr); + Point point = new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize()); + return getFlipsForPotentialMove(point, curr, opp, cachedBoard); + } + + private Move[] getFlipsInDirection(Point point, char[][] boardGrid, char currentPlayer, char opponent, int dirX, int dirY) { final ArrayList movesToFlip = new ArrayList<>(); int x = point.x + dirX; int y = point.y + dirY; @@ -123,7 +135,7 @@ public final class Reversi extends TurnBasedGame { return x >= 0 && x < this.getColumnSize() && y >= 0 && y < this.getRowSize(); } - private char[][] makeBoardAGrid() { + public char[][] makeBoardAGrid() { char[][] boardGrid = new char[this.getRowSize()][this.getColumnSize()]; for (int i = 0; i < 64; i++) { boardGrid[i / this.getRowSize()][i % this.getColumnSize()] = this.getBoard()[i]; //boardGrid[y -> row] [x -> column] @@ -133,6 +145,9 @@ public final class Reversi extends TurnBasedGame { @Override public GameState play(Move move) { + if (cachedBoard == null) { + cachedBoard = makeBoardAGrid(); 
+ } Move[] legalMoves = getLegalMoves(); boolean moveIsLegal = false; for (Move legalMove : legalMoves) { //check if the move is legal @@ -145,13 +160,14 @@ public final class Reversi extends TurnBasedGame { return null; } - Move[] moves = sortMovesFromCenter(getFlipsForPotentialMove(new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()), move.value()),move); + Move[] moves = sortMovesFromCenter(getFlipsForPotentialMove(new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()), move.value(),move.value() == 'B'? 'W': 'B',makeBoardAGrid()),move); mostRecentlyFlippedPieces = moves; this.setBoard(move); //place the move on the board for (Move m : moves) { this.setBoard(m); //flip the correct pieces on the board } filledCells.add(new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize())); + cachedBoard = makeBoardAGrid(); nextTurn(); if (getLegalMoves().length == 0) { //skip the players turn when there are no legal moves skipMyTurn(); @@ -172,7 +188,7 @@ public final class Reversi extends TurnBasedGame { } private void skipMyTurn(){ - IO.println("TURN " + getCurrentPlayer() + " SKIPPED"); + //IO.println("TURN " + getCurrentPlayer() + " SKIPPED"); //TODO: notify user that a turn has been skipped nextTurn(); } @@ -207,6 +223,32 @@ public final class Reversi extends TurnBasedGame { } return new Score(player1Score, player2Score); } + + public boolean isGameOver(){ + Move[] legalMovesW = getLegalMoves(); + nextTurn(); + Move[] legalMovesB = getLegalMoves(); + nextTurn(); + if (legalMovesW.length + legalMovesB.length == 0) { + return true; + } + return false; + } + + public int getWinner(){ + if (!isGameOver()) { + return 0; + } + Score score = getScore(); + if (score.player1Score() > score.player2Score()) { + return 1; + } + else if (score.player1Score() < score.player2Score()) { + return 2; + } + return 0; + } + private Move[] sortMovesFromCenter(Move[] moves, Move center) { //sorts the pieces 
to be flipped for animation purposes int centerX = center.position()%this.getColumnSize(); int centerY = center.position()/this.getRowSize(); @@ -226,4 +268,34 @@ public final class Reversi extends TurnBasedGame { public Move[] getMostRecentlyFlippedPieces() { return mostRecentlyFlippedPieces; } + + public int[] getBoardInt(){ + char[] input = getBoard(); + int[] result = new int[input.length]; + for (int i = 0; i < input.length; i++) { + switch (input[i]) { + case 'W': + result[i] = -1; + break; + case 'B': + result[i] = 1; + break; + case ' ': + default: + result[i] = 0; + break; + } + } + return result; + } + + public Point moveToPoint(Move move){ + return new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()); + } + + public void printBoard(){ + for (int row = 0; row < this.getRowSize(); row++) { + IO.println(Arrays.toString(cachedBoard[row])); + } + } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/reversi/ReversiAI.java b/game/src/main/java/org/toop/game/reversi/ReversiAI.java index 2fc78d5..8e64509 100644 --- a/game/src/main/java/org/toop/game/reversi/ReversiAI.java +++ b/game/src/main/java/org/toop/game/reversi/ReversiAI.java @@ -4,6 +4,7 @@ import org.toop.game.AI; import org.toop.game.records.Move; public final class ReversiAI extends AI { + @Override public Move findBestMove(Reversi game, int depth) { Move[] moves = game.getLegalMoves(); diff --git a/game/src/main/java/org/toop/game/reversi/ReversiAIML.java b/game/src/main/java/org/toop/game/reversi/ReversiAIML.java new file mode 100644 index 0000000..78f6915 --- /dev/null +++ b/game/src/main/java/org/toop/game/reversi/ReversiAIML.java @@ -0,0 +1,52 @@ +package org.toop.game.reversi; + +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.toop.game.AI; +import org.toop.game.records.Move; + +import 
java.io.IOException; +import java.io.InputStream; + +public class ReversiAIML extends AI{ + + MultiLayerNetwork model; + + public ReversiAIML() { + InputStream is = getClass().getResourceAsStream("/reversi-model.zip"); + try { + assert is != null; + model = ModelSerializer.restoreMultiLayerNetwork(is); + } catch (IOException e) {} + } + + public Move findBestMove(Reversi reversi, int depth){ + int[] input = reversi.getBoardInt(); + + INDArray boardInput = Nd4j.create(new int[][] { input }); + INDArray prediction = model.output(boardInput); + + int move = pickLegalMove(prediction, reversi); + return new Move(move, reversi.getCurrentPlayer()); + } + + private int pickLegalMove(INDArray prediction, Reversi reversi){ + double[] probs = prediction.toDoubleVector(); + Move[] legalMoves = reversi.getLegalMoves(); + + if (legalMoves.length == 0) return -1; + + int bestMove = legalMoves[0].position(); + double bestVal = probs[bestMove]; + + for (Move move : legalMoves){ + if (probs[move.position()] > bestVal){ + bestMove = move.position(); + bestVal = probs[bestMove]; + } + } + return bestMove; + } +} diff --git a/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java b/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java new file mode 100644 index 0000000..109dcc9 --- /dev/null +++ b/game/src/main/java/org/toop/game/reversi/ReversiAISimple.java @@ -0,0 +1,57 @@ +package org.toop.game.reversi; + +import org.toop.game.AI; +import org.toop.game.records.Move; + +import java.util.Arrays; + +public class ReversiAISimple extends AI { + + @Override + public Move findBestMove(Reversi game, int depth) { + //IO.println("****START FIND BEST MOVE****"); + + Move[] moves = game.getLegalMoves(); + + + //game.printBoard(); + //IO.println("Legal moves: " + Arrays.toString(moves)); + + Move bestMove; + Move bestMoveScore = moves[0]; + Move bestMoveOptions = moves[0]; + int bestScore = -1; + int bestOptions = -1; + for (Move move : moves){ + int numOpt = 
getNumberOfOptions(game, move); + if (numOpt > bestOptions) { + bestOptions = numOpt; + bestMoveOptions = move; + } + int numSco = getScore(game, move); + if (numSco > bestScore) { + bestScore = numSco; + bestMoveScore = move; + } + + //IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco); + } + if (bestScore > bestOptions) { + bestMove = bestMoveScore; + } + else{ + bestMove = bestMoveOptions; + } + return bestMove; + } + + private int getNumberOfOptions(Reversi game, Move move){ + Reversi copy = new Reversi(game); + copy.play(move); + return copy.getLegalMoves().length; + } + + private int getScore(Reversi game, Move move){ + return game.getFlipsForPotentialMove(move).length; + } +} diff --git a/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java b/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java index f2586a5..4f2c80f 100644 --- a/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java +++ b/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java @@ -8,17 +8,24 @@ import org.toop.game.enumerators.GameState; import org.toop.game.records.Move; import org.toop.game.reversi.Reversi; import org.toop.game.reversi.ReversiAI; +import org.toop.game.reversi.ReversiAIML; +import org.toop.game.reversi.ReversiAISimple; import static org.junit.jupiter.api.Assertions.*; class ReversiTest { private Reversi game; private ReversiAI ai; + private ReversiAIML aiml; + private ReversiAISimple aiSimple; @BeforeEach void setup() { game = new Reversi(); ai = new ReversiAI(); + aiml = new ReversiAIML(); + aiSimple = new ReversiAISimple(); + } @@ -190,4 +197,35 @@ class ReversiTest { } return false; } + + @Test + void testAIvsAIML(){ + IO.println("Testing AI simple ..."); + int totalGames = 5000; + int p1wins = 0; + int p2wins = 0; + int draws = 0; + for (int i = 0; i < totalGames; i++) { + game = new Reversi(); + while (!game.isGameOver()) { + char curr = game.getCurrentPlayer(); + if (curr == 'W') { + 
game.play(ai.findBestMove(game,5)); } else { game.play(aiml.findBestMove(game,5)); } } int winner = game.getWinner(); if (winner == 1) { p1wins++; }else if (winner == 2) { p2wins++; } else{ draws++; } } IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws); } }