diff --git a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java
index f380bef..4cb60f9 100644
--- a/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java
+++ b/game/src/main/java/org/toop/game/games/reversi/BitboardReversi.java
@@ -167,4 +167,13 @@ public class BitboardReversi extends BitboardGame {
     private long shift(long bit, int shift, long mask) {
         return shift > 0 ? (bit << shift) & mask : (bit >>> -shift) & mask;
     }
+
+    public boolean isGameOver(){
+        BitboardReversi copy = this.deepCopy();
+        if (copy.getLegalMoves() == 0){
+            copy.nextTurn(); // pass on the copy, not the live game: the game only ends when neither player can move
+            return copy.getLegalMoves() == 0;
+        }
+        return false;
+    }
 }
\ No newline at end of file
diff --git a/game/src/main/java/org/toop/game/games/reversi/ReversiAIML.java b/game/src/main/java/org/toop/game/games/reversi/ReversiAIML.java
deleted file mode 100644
index 6c46905..0000000
--- a/game/src/main/java/org/toop/game/games/reversi/ReversiAIML.java
+++ /dev/null
@@ -1,57 +0,0 @@
-package org.toop.game.games.reversi;
-
-import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-import org.deeplearning4j.util.ModelSerializer;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.toop.framework.gameFramework.model.player.AbstractAI;
-
-import java.io.IOException;
-import java.io.InputStream;
-
-import static java.lang.Math.random;
-
-public class ReversiAIML extends AbstractAI {
-
-    MultiLayerNetwork model;
-
-    public ReversiAIML() {
-        InputStream is = getClass().getResourceAsStream("/reversi-model.zip");
-        try {
-            assert is != null;
-            model = ModelSerializer.restoreMultiLayerNetwork(is);
-        } catch (IOException e) {}
-    }
-
-    private int pickLegalMove(INDArray prediction, ReversiR reversi) {
-        double[] logits = prediction.toDoubleVector();
-        int[] legalMoves = reversi.getLegalMoves();
-
-        if (legalMoves.length == 0) return -1;
-
-        int bestMove = legalMoves[0];
-        double bestVal = logits[bestMove];
-
-        if (random() < 0.01){
-            return legalMoves[(int)(random()*legalMoves.length-.5)];
-        }
-        for (int move : legalMoves) {
-            if (logits[move] > bestVal) {
-                bestMove = move;
-                bestVal = logits[move];
-            }
-        }
-        return bestMove;
-    }
-
-    @Override
-    public int getMove(ReversiR game) {
-        int[] input = game.getBoard();
-
-        INDArray boardInput = Nd4j.create(new int[][] { input });
-        INDArray prediction = model.output(boardInput);
-
-        int move = pickLegalMove(prediction,game);
-        return move;
-    }
-}
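Note on the shift(bit, shift, mask) helper in the BitboardReversi hunk above: the mask argument is what stops discs from wrapping across board edges when the 64-bit board is shifted. A minimal standalone sketch of the idea, assuming the common LSB-first layout (bit 0 = a1) and a hypothetical NOT_A_FILE constant, neither of which is taken from this repo:

    // East shift: without the mask, an h-file disc would wrap onto the
    // next rank's a-file. Clearing the a-file after the shift mirrors the
    // helper's (bit << shift) & mask branch.
    static final long NOT_A_FILE = 0xFEFEFEFEFEFEFEFEL;

    static long east(long pieces) {
        return (pieces << 1) & NOT_A_FILE;
    }

    // east(1L << 7) == 0L  -> a disc on h1 falls off instead of wrapping to a2
    // east(1L)      == 2L  -> a1 moves to b1 as expected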
diff --git a/game/src/main/java/org/toop/game/games/reversi/ReversiAISimple.java b/game/src/main/java/org/toop/game/games/reversi/ReversiAISimple.java
deleted file mode 100644
index d744b45..0000000
--- a/game/src/main/java/org/toop/game/games/reversi/ReversiAISimple.java
+++ /dev/null
@@ -1,58 +0,0 @@
-package org.toop.game.games.reversi;
-
-import org.toop.framework.gameFramework.model.player.AbstractAI;
-
-import java.awt.*;
-
-public class ReversiAISimple extends AbstractAI {
-
-
-    private int getNumberOfOptions(ReversiR game, int move){
-        ReversiR copy = game.deepCopy();
-        copy.play(move);
-        return copy.getLegalMoves().length;
-    }
-
-    private int getScore(ReversiR game, int move){
-        return game.getFlipsForPotentialMove(new Point(move%game.getColumnSize(),move/game.getRowSize()),game.getCurrentTurn()).length;
-    }
-
-    @Override
-    public int getMove(ReversiR game) {
-
-        int[] moves = game.getLegalMoves();
-
-        int bestMove;
-        int bestMoveScore = moves[0];
-        int bestMoveOptions = moves[0];
-        int bestScore = -1;
-        int bestOptions = -1;
-
-        for (int move : moves){
-            int numOpt = getNumberOfOptions(game, move);
-            if (numOpt > bestOptions) {
-                bestOptions = numOpt;
-                bestMoveOptions = move;
-            }
-            int numSco = getScore(game, move);
-            if (numSco > bestScore) {
-                bestScore = numSco;
-                bestMoveScore = move;
-            }
-            if (numSco == bestScore || numOpt == bestOptions) {
-                if (Math.random() < 0.5) {
-                    bestMoveOptions = move;
-                    bestMoveScore = move;
-                }
-            }
-
-            //IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
-        }
-        if (bestScore > bestOptions) {
-            bestMove = bestMoveScore;
-        }
-        else{
-            bestMove = bestMoveOptions;
-        }
-        return bestMove;
-    }
-}
diff --git a/game/src/main/java/org/toop/game/machinelearning/NeuralNetwork.java b/game/src/main/java/org/toop/game/machinelearning/NeuralNetwork.java
index e00c88e..4c6893d 100644
--- a/game/src/main/java/org/toop/game/machinelearning/NeuralNetwork.java
+++ b/game/src/main/java/org/toop/game/machinelearning/NeuralNetwork.java
@@ -17,10 +17,11 @@ import org.toop.framework.gameFramework.GameState;
 import org.toop.framework.gameFramework.model.game.PlayResult;
 import org.toop.framework.gameFramework.model.player.AbstractAI;
 import org.toop.framework.gameFramework.model.player.Player;
-import org.toop.game.games.reversi.ReversiAIR;
-import org.toop.game.games.reversi.ReversiR;
-import org.toop.game.games.reversi.ReversiAIML;
-import org.toop.game.games.reversi.ReversiAISimple;
+import org.toop.game.games.reversi.BitboardReversi;
+import org.toop.game.players.ArtificialPlayer;
+import org.toop.game.players.ai.MiniMaxAI;
+import org.toop.game.players.ai.RandomAI;
+import org.toop.game.players.ai.ReversiAIML;
 
 import java.io.File;
 import java.io.IOException;
@@ -34,15 +35,18 @@ public class NeuralNetwork {
     private MultiLayerConfiguration conf;
     private MultiLayerNetwork model;
 
-    private AbstractAI opponentAI;
-    private AbstractAI opponentRand = new ReversiAIR();
-    private AbstractAI opponentSimple = new ReversiAISimple();
-    private AbstractAI opponentAIML = new ReversiAIML();
+    private AbstractAI<BitboardReversi> opponentAI;
+    private AbstractAI<BitboardReversi> opponentMM = new MiniMaxAI<>(6);
+    private AbstractAI<BitboardReversi> opponentRand = new RandomAI<>();
+    private AbstractAI<BitboardReversi> opponentAIML = new ReversiAIML<>();
+
+    private Player[] playerSet = new Player[4];
 
     public NeuralNetwork() {}
 
     public void init(){
+        initPlayers();
+
         conf = new NeuralNetConfiguration.Builder()
                 .updater(new Adam(0.001))
                 .weightInit(WeightInit.XAVIER) //todo understand
@@ -70,6 +74,12 @@ public class NeuralNetwork {
         saveModel();
     }
 
+    public void initPlayers(){
+        playerSet[0] = new ArtificialPlayer<>(new MiniMaxAI(6),"MiniMaxAI");
+        playerSet[1] = new ArtificialPlayer<>(new RandomAI(),"RandomAI");
+        playerSet[2] = new ArtificialPlayer<>(new ReversiAIML(),"MachineLearningAI");
+    }
+
     public void saveModel(){
         File modelFile = new File("reversi-model.zip");
         try {
@@ -92,11 +102,13 @@ public class NeuralNetwork {
         int totalGames = 5000;
         double epsilon = 0.05;
 
+        long start = System.nanoTime();
+
         for (int game = 0; game < totalGames; game++){
             List<StateAction> gameHistory = new ArrayList<>();
             PlayResult state = new PlayResult(GameState.NORMAL,reversi.getCurrentTurn());
@@ -105,14 +117,14 @@ public class NeuralNetwork {
             while (state.state() != GameState.DRAW && state.state() != GameState.WIN){
                 int curr = reversi.getCurrentTurn();
-                int move;
+                long move;
 
                 if (curr == modelPlayer) {
-                    int[] input = reversi.getBoard();
+                    long[] input = reversi.getBoard();
 
                     if (Math.random() < epsilon) {
-                        int[] moves = reversi.getLegalMoves();
-                        move = moves[(int) (Math.random() * moves.length - .5f)];
+                        long moves = reversi.getLegalMoves();
+                        // pick a uniformly random set bit (legal square) from the mask
+                        int randomIndex = (int) (Math.random() * Long.bitCount(moves));
+                        for (int i = 0; i < randomIndex; i++) {
+                            moves &= moves - 1;
+                        }
+                        move = Long.numberOfTrailingZeros(moves);
                     } else {
-                        INDArray boardInput = Nd4j.create(new int[][]{input});
+                        INDArray boardInput = Nd4j.create(new long[][]{input});
                         INDArray prediction = model.output(boardInput);
 
                         int location = pickLegalMove(prediction, reversi);
@@ -126,11 +138,11 @@ public class NeuralNetwork {
             }
 
             //IO.println(model.params());
-            ReversiR.Score score = reversi.getScore();
-            int scoreDif = abs(score.player1Score() - score.player2Score());
-            if (score.player1Score() > score.player2Score()){
+            BitboardReversi.Score score = reversi.getScore();
+            int scoreDif = abs(score.black() - score.white());
+            if (score.black() > score.white()){
                reward = 1 + ((scoreDif / 64.0) * 0.5);
-            }else if (score.player1Score() < score.player2Score()){
+            }else if (score.black() < score.white()){
                reward = -1 - ((scoreDif / 64.0) * 0.5);
             }else{
                reward = 0;
@@ -156,28 +168,46 @@ public class NeuralNetwork {
     }
 
-    private int pickLegalMove(INDArray prediction, ReversiR reversi){
-        double[] probs = prediction.toDoubleVector();
-        int[] legalMoves = reversi.getLegalMoves();
+    private int pickLegalMove(INDArray prediction, BitboardReversi reversi) {
+        double[] logits = prediction.toDoubleVector();
+        long legalMoves = reversi.getLegalMoves();
 
-        if (legalMoves.length == 0) return -1;
-
-        int bestMove = legalMoves[0];
-        double bestVal = probs[bestMove];
-
-        for (int move : legalMoves){
-            if (probs[move] > bestVal){
-                bestMove = move;
-                bestVal = probs[move];
-            }
-        }
+        if (legalMoves == 0L) {
+            return -1;
+        }
+
+        if (Math.random() < 0.01) {
+            int randomIndex = (int) (Math.random() * Long.bitCount(legalMoves));
+            long moves = legalMoves;
+            for (int i = 0; i < randomIndex; i++) {
+                moves &= moves - 1;
+            }
+            return Long.numberOfTrailingZeros(moves);
+        }
+
+        int bestMove = -1;
+        double bestVal = Double.NEGATIVE_INFINITY;
+
+        long moves = legalMoves;
+        while (moves != 0L) {
+            int move = Long.numberOfTrailingZeros(moves);
+            double value = logits[move];
+
+            if (value > bestVal) {
+                bestVal = value;
+                bestMove = move;
+            }
+
+            moves &= moves - 1;
+        }
+
         return bestMove;
     }
 
-    private AbstractAI getOpponentAI(){
+    private AbstractAI<BitboardReversi> getOpponentAI(){
         return switch ((int) (Math.random() * 4)) {
             case 0 -> opponentRand;
-            case 1 -> opponentSimple;
+            case 1 -> opponentMM;
             case 2 -> opponentAIML;
             default -> opponentRand;
         };
@@ -188,7 +218,7 @@ public class NeuralNetwork {
         output[step.action] = reward;
 
         DataSet ds = new DataSet(
-                Nd4j.create(new int[][] { step.state }),
+                Nd4j.create(new long[][] { step.state }),
                 Nd4j.create(new double[][] { output })
         );
diff --git a/game/src/main/java/org/toop/game/machinelearning/StateAction.java b/game/src/main/java/org/toop/game/machinelearning/StateAction.java
index 561230b..0200f10 100644
--- a/game/src/main/java/org/toop/game/machinelearning/StateAction.java
+++ b/game/src/main/java/org/toop/game/machinelearning/StateAction.java
@@ -1,9 +1,9 @@
 package org.toop.game.machinelearning;
 
 public class StateAction {
-    int[] state;
+    long[] state;
     int action;
 
-    public StateAction(int[] state, int action) {
+    public StateAction(long[] state, int action) {
         this.state = state;
         this.action = action;
     }
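The rewritten pickLegalMove above (duplicated in the new ReversiAIML below) relies on two bit idioms: Long.numberOfTrailingZeros reads the square index of the lowest set bit, and moves &= moves - 1 clears it. A self-contained sketch with an arbitrary example mask:

    long moves = 0b10110L;                              // legal squares 1, 2 and 4
    while (moves != 0L) {
        int square = Long.numberOfTrailingZeros(moves); // index of the lowest set bit
        System.out.println(square);                     // prints 1, then 2, then 4
        moves &= moves - 1;                             // drop that bit, keep the rest
    }

The loop runs Long.bitCount(moves) times, once per legal move, rather than scanning all 64 squares.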
diff --git a/game/src/main/java/org/toop/game/players/ai/ReversiAIML.java b/game/src/main/java/org/toop/game/players/ai/ReversiAIML.java
new file mode 100644
index 0000000..d6a97af
--- /dev/null
+++ b/game/src/main/java/org/toop/game/players/ai/ReversiAIML.java
@@ -0,0 +1,80 @@
+package org.toop.game.players.ai;
+
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.toop.framework.gameFramework.model.game.TurnBasedGame;
+import org.toop.framework.gameFramework.model.player.AI;
+import org.toop.framework.gameFramework.model.player.AbstractAI;
+import org.toop.game.games.reversi.BitboardReversi;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+import static java.lang.Math.random;
+
+public class ReversiAIML<T extends TurnBasedGame<T>> extends AbstractAI<T> {
+
+    MultiLayerNetwork model;
+
+    public ReversiAIML() {
+        InputStream is = getClass().getResourceAsStream("/reversi-model.zip");
+        try {
+            assert is != null;
+            model = ModelSerializer.restoreMultiLayerNetwork(is);
+        } catch (IOException e) {}
+    }
+
+    private int pickLegalMove(INDArray prediction, BitboardReversi reversi) {
+        double[] logits = prediction.toDoubleVector();
+        long legalMoves = reversi.getLegalMoves();
+
+        if (legalMoves == 0L) {
+            return -1;
+        }
+
+        if (Math.random() < 0.01) {
+            int randomIndex = (int) (Math.random() * Long.bitCount(legalMoves));
+            long moves = legalMoves;
+            for (int i = 0; i < randomIndex; i++) {
+                moves &= moves - 1;
+            }
+            return Long.numberOfTrailingZeros(moves);
+        }
+
+        int bestMove = -1;
+        double bestVal = Double.NEGATIVE_INFINITY;
+
+        long moves = legalMoves;
+        while (moves != 0L) {
+            int move = Long.numberOfTrailingZeros(moves);
+            double value = logits[move];
+
+            if (value > bestVal) {
+                bestVal = value;
+                bestMove = move;
+            }
+
+            moves &= moves - 1;
+        }
+
+        return bestMove;
+    }
+
+    @Override
+    public long getMove(T game) {
+        long[] input = game.getBoard();
+
+        INDArray boardInput = Nd4j.create(new long[][] { input });
+        INDArray prediction = model.output(boardInput);
+
+        int move = pickLegalMove(prediction,(BitboardReversi) game);
+        return move;
+    }
+
+    @Override
+    public ReversiAIML<T> deepCopy() {
+        return new ReversiAIML<>();
+    }
+}
diff --git a/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java b/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java
index cd83b63..1a807a6 100644
--- a/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java
+++ b/game/src/test/java/org/toop/game/tictactoe/ReversiTest.java
@@ -17,7 +17,7 @@ import org.toop.game.games.reversi.ReversiR;
 import org.toop.game.records.Move;
 import org.toop.game.reversi.Reversi;
 import org.toop.game.reversi.ReversiAI;
-import org.toop.game.games.reversi.ReversiAIML;
+import org.toop.game.players.ai.ReversiAIML;
 import org.toop.game.games.reversi.ReversiAISimple;
 
 import static org.junit.jupiter.api.Assertions.*;
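The start-position asserts in TestReversi below check bits 27, 28, 35 and 36. Assuming a row-major square = row * 8 + col layout with bit 0 in a corner (the repo may orient its board differently), those indices are exactly the four centre squares:

    int d4 = 3 * 8 + 3; // 27, white
    int e4 = 3 * 8 + 4; // 28, black
    int d5 = 4 * 8 + 3; // 35, black
    int e5 = 4 * 8 + 4; // 36, white

which matches the standard Reversi opening: white on one centre diagonal, black on the other.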
diff --git a/game/src/test/java/org/toop/game/tictactoe/TestReversi.java b/game/src/test/java/org/toop/game/tictactoe/TestReversi.java
new file mode 100644
index 0000000..269bd0c
--- /dev/null
+++ b/game/src/test/java/org/toop/game/tictactoe/TestReversi.java
@@ -0,0 +1,78 @@
+package org.toop.game.tictactoe;
+
+import java.util.*;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.toop.framework.gameFramework.model.player.Player;
+import org.toop.game.games.reversi.BitboardReversi;
+import org.toop.game.players.ArtificialPlayer;
+import org.toop.game.players.ai.MiniMaxAI;
+import org.toop.game.players.ai.RandomAI;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class TestReversi {
+    private BitboardReversi game;
+    private Player[] players;
+
+    @BeforeEach
+    void setup(){
+        players = new Player[2];
+        players[0] = new ArtificialPlayer(new RandomAI(),"randomAI");
+        players[1] = new ArtificialPlayer(new MiniMaxAI(10),"miniMaxAI");
+        game = new BitboardReversi(players);
+    }
+
+    @Test
+    void testCorrectStartPiecesPlaced() {
+        assertNotNull(game);
+        long[] board = game.getBoard();
+        IO.println(Long.toBinaryString(board[0]));
+        IO.println(Long.toBinaryString(board[1]));
+        long black = board[0];
+        long white = board[1];
+        assertEquals(1L, ((white >>> 27) & 1L)); // bit 27 of the white bitboard must be set
+        assertEquals(1L, ((black >>> 28) & 1L));
+        assertEquals(1L, ((black >>> 35) & 1L));
+        assertEquals(1L, ((white >>> 36) & 1L));
+    }
+
+    @Test
+    void testPlayGames(){
+        int totalGames = 1;
+        long start = System.nanoTime();
+        long midtime = System.nanoTime();
+        int p1wins = 0;
+        int p2wins = 0;
+        int draws = 0;
+
+        for (int i = 0; i < totalGames; i++){
+            game = new BitboardReversi(players);
+            while(!game.isGameOver()){
+                midtime = System.nanoTime();
+                int currentTurn = game.getCurrentTurn();
+                long move = players[currentTurn].getMove(game.deepCopy());
+                game.play(move);
+                IO.println(System.nanoTime() - midtime);
+            }
+            switch (game.getWinner()){
+                case 0:
+                    p1wins++;
+                    break;
+                case 1:
+                    p2wins++;
+                    break;
+                case -1:
+                    draws++;
+                    break;
+            }
+        }
+        System.out.println(System.nanoTime() - start);
+        IO.println(p1wins + " " + p2wins + " " + draws);
+        assertEquals(totalGames, p1wins + p2wins + draws);
+        IO.println("p1 wr: " + p1wins + "/" + totalGames + " = " + (double) p1wins / totalGames);
+    }
+}
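Both the epsilon branch in NeuralNetwork and the 1% exploration branch in pickLegalMove need a uniformly random set bit from a legal-move mask, so the logic could live in one named helper. A self-contained sketch; the class and method names are illustrative, not part of the repo:

    import java.util.concurrent.ThreadLocalRandom;

    final class BitboardMoves {
        /** Square index of a uniformly random set bit of mask; mask must be non-zero. */
        static int randomSetBit(long mask) {
            int skip = ThreadLocalRandom.current().nextInt(Long.bitCount(mask));
            for (int i = 0; i < skip; i++) {
                mask &= mask - 1; // discard the lowest set bit
            }
            return Long.numberOfTrailingZeros(mask);
        }
    }

Skipping a random number of low bits and then reading the lowest remaining one selects each legal square with equal probability, which the old Math.random() * n - .5f indexing did not.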