mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 10:54:51 +00:00
Compare commits
1 Commits
223-create
...
7e913ff50f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e913ff50f |
@@ -2,6 +2,7 @@ package org.toop;
|
|||||||
|
|
||||||
import org.toop.app.App;
|
import org.toop.app.App;
|
||||||
import org.toop.framework.audio.*;
|
import org.toop.framework.audio.*;
|
||||||
|
import org.toop.framework.machinelearning.NeuralNetwork;
|
||||||
import org.toop.framework.networking.NetworkingClientEventListener;
|
import org.toop.framework.networking.NetworkingClientEventListener;
|
||||||
import org.toop.framework.networking.NetworkingClientManager;
|
import org.toop.framework.networking.NetworkingClientManager;
|
||||||
import org.toop.framework.resource.ResourceLoader;
|
import org.toop.framework.resource.ResourceLoader;
|
||||||
@@ -35,6 +36,10 @@ public final class Main {
|
|||||||
).initListeners("medium-button-click.wav");
|
).initListeners("medium-button-click.wav");
|
||||||
|
|
||||||
}).start();
|
}).start();
|
||||||
|
|
||||||
|
|
||||||
|
//NeuralNetwork nn = new NeuralNetwork();
|
||||||
|
//nn.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import org.toop.game.reversi.ReversiAI;
|
|||||||
|
|
||||||
import javafx.geometry.Pos;
|
import javafx.geometry.Pos;
|
||||||
import javafx.scene.paint.Color;
|
import javafx.scene.paint.Color;
|
||||||
|
import org.toop.game.reversi.ReversiAIML;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.util.concurrent.BlockingQueue;
|
import java.util.concurrent.BlockingQueue;
|
||||||
@@ -30,7 +31,7 @@ public final class ReversiGame {
|
|||||||
private final BlockingQueue<Move> moveQueue;
|
private final BlockingQueue<Move> moveQueue;
|
||||||
|
|
||||||
private final Reversi game;
|
private final Reversi game;
|
||||||
private final ReversiAI ai;
|
private final ReversiAIML ai;
|
||||||
|
|
||||||
private final GameView primary;
|
private final GameView primary;
|
||||||
private final ReversiCanvas canvas;
|
private final ReversiCanvas canvas;
|
||||||
@@ -46,7 +47,7 @@ public final class ReversiGame {
|
|||||||
moveQueue = new LinkedBlockingQueue<>();
|
moveQueue = new LinkedBlockingQueue<>();
|
||||||
|
|
||||||
game = new Reversi();
|
game = new Reversi();
|
||||||
ai = new ReversiAI();
|
ai = new ReversiAIML();
|
||||||
|
|
||||||
isRunning = new AtomicBoolean(true);
|
isRunning = new AtomicBoolean(true);
|
||||||
isPaused = new AtomicBoolean(false);
|
isPaused = new AtomicBoolean(false);
|
||||||
@@ -324,7 +325,7 @@ public final class ReversiGame {
|
|||||||
if (isLegalMove) {
|
if (isLegalMove) {
|
||||||
moves = game.getFlipsForPotentialMove(
|
moves = game.getFlipsForPotentialMove(
|
||||||
new Point(cellEntered%game.getColumnSize(),cellEntered/game.getRowSize()),
|
new Point(cellEntered%game.getColumnSize(),cellEntered/game.getRowSize()),
|
||||||
game.getCurrentPlayer());
|
game.getCurrentPlayer(), game.getCurrentPlayer() == 'B'?'W':'B',game.makeBoardAGrid());
|
||||||
}
|
}
|
||||||
canvas.drawHighlightDots(moves);
|
canvas.drawHighlightDots(moves);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,19 @@
|
|||||||
<artifactId>error_prone_annotations</artifactId>
|
<artifactId>error_prone_annotations</artifactId>
|
||||||
<version>2.42.0</version>
|
<version>2.42.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
<dependency>
|
||||||
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j-nn</artifactId>
|
||||||
|
<version>1.0.0-M2.1</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.toop</groupId>
|
||||||
|
<artifactId>game</artifactId>
|
||||||
|
<version>0.1</version>
|
||||||
|
<scope>compile</scope>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
<build>
|
<build>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
|||||||
@@ -0,0 +1,182 @@
|
|||||||
|
package org.toop.framework.machinelearning;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
import org.toop.game.AI;
|
||||||
|
import org.toop.game.enumerators.GameState;
|
||||||
|
import org.toop.game.records.Move;
|
||||||
|
import org.toop.game.reversi.Reversi;
|
||||||
|
import org.toop.game.reversi.ReversiAI;
|
||||||
|
import org.toop.game.reversi.ReversiAISimple;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static java.lang.Math.abs;
|
||||||
|
|
||||||
|
public class NeuralNetwork {
|
||||||
|
|
||||||
|
private MultiLayerConfiguration conf;
|
||||||
|
private MultiLayerNetwork model;
|
||||||
|
private ReversiAI reversiAI = new ReversiAI();
|
||||||
|
private AI<Reversi> opponentRand = new ReversiAI();
|
||||||
|
private AI<Reversi> opponentSimple = new ReversiAISimple();
|
||||||
|
|
||||||
|
|
||||||
|
public NeuralNetwork() {}
|
||||||
|
|
||||||
|
public void init(){
|
||||||
|
conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.updater(new Adam(0.001))
|
||||||
|
.weightInit(WeightInit.XAVIER) //todo understand
|
||||||
|
.list()
|
||||||
|
.layer(new DenseLayer.Builder()
|
||||||
|
.nIn(64)
|
||||||
|
.nOut(128)
|
||||||
|
.activation(Activation.RELU)
|
||||||
|
.build())
|
||||||
|
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.nIn(128)
|
||||||
|
.nOut(64)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.build())
|
||||||
|
.build();
|
||||||
|
model = new MultiLayerNetwork(conf);
|
||||||
|
IO.println(model.params());
|
||||||
|
loadModel();
|
||||||
|
IO.println(model.params());
|
||||||
|
model.init();
|
||||||
|
IO.println(model.summary());
|
||||||
|
|
||||||
|
model.setLearningRate(0.001);
|
||||||
|
trainingLoop();
|
||||||
|
saveModel();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveModel(){
|
||||||
|
File modelFile = new File("reversi-model.zip");
|
||||||
|
try {
|
||||||
|
ModelSerializer.writeModel(model, modelFile, true);
|
||||||
|
}catch (Exception e){
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void loadModel(){
|
||||||
|
File modelFile = new File("reversi-model.zip");
|
||||||
|
try {
|
||||||
|
model = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void trainingLoop(){
|
||||||
|
int totalGames = 5000;
|
||||||
|
double epsilon = 0.1;
|
||||||
|
|
||||||
|
long start = System.nanoTime();
|
||||||
|
|
||||||
|
for (int game = 0; game<totalGames; game++){
|
||||||
|
Reversi reversi = new Reversi();
|
||||||
|
List<StateAction> gameHistory = new ArrayList<>();
|
||||||
|
GameState state = GameState.NORMAL;
|
||||||
|
|
||||||
|
double reward = 0;
|
||||||
|
|
||||||
|
while (state != GameState.DRAW && state != GameState.WIN){
|
||||||
|
char curr = reversi.getCurrentPlayer();
|
||||||
|
Move move;
|
||||||
|
if (curr == 'B') {
|
||||||
|
int[] input = reversi.getBoardInt();
|
||||||
|
if (Math.random() < epsilon) {
|
||||||
|
Move[] moves = reversi.getLegalMoves();
|
||||||
|
move = moves[(int) (Math.random() * moves.length - .5f)];
|
||||||
|
} else {
|
||||||
|
INDArray boardInput = Nd4j.create(new int[][]{input});
|
||||||
|
INDArray prediction = model.output(boardInput);
|
||||||
|
|
||||||
|
int location = pickLegalMove(prediction, reversi);
|
||||||
|
gameHistory.add(new StateAction(input, location));
|
||||||
|
move = new Move(location, reversi.getCurrentPlayer());
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
move = reversiAI.findBestMove(reversi,5);
|
||||||
|
}
|
||||||
|
state = reversi.play(move);
|
||||||
|
}
|
||||||
|
|
||||||
|
//IO.println(model.params());
|
||||||
|
Reversi.Score score = reversi.getScore();
|
||||||
|
int scoreDif = abs(score.player1Score() - score.player2Score());
|
||||||
|
if (score.player1Score() > score.player2Score()){
|
||||||
|
reward = 1 + ((scoreDif / 64.0) * 0.5);
|
||||||
|
}else if (score.player1Score() < score.player2Score()){
|
||||||
|
reward = -1 - ((scoreDif / 64.0) * 0.5);
|
||||||
|
}else{
|
||||||
|
reward = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (StateAction step : gameHistory){
|
||||||
|
trainFromHistory(step, reward);
|
||||||
|
}
|
||||||
|
|
||||||
|
//IO.println("Wr: " + (double)p1wins/(game+1) + " draws: " + draws);
|
||||||
|
if(game % 100 == 0){
|
||||||
|
IO.println("Completed game " + game + " | Reward: " + reward);
|
||||||
|
//IO.println(Arrays.toString(reversi.getBoardDouble()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
long end = System.nanoTime();
|
||||||
|
IO.println((end-start));
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isInCorner(Move move){
|
||||||
|
return move.position() == 0 || move.position() == 7 || move.position() == 56 || move.position() == 63;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int pickLegalMove(INDArray prediction, Reversi reversi){
|
||||||
|
double[] probs = prediction.toDoubleVector();
|
||||||
|
Move[] legalMoves = reversi.getLegalMoves();
|
||||||
|
|
||||||
|
if (legalMoves.length == 0) return -1;
|
||||||
|
|
||||||
|
int bestMove = legalMoves[0].position();
|
||||||
|
double bestVal = probs[bestMove];
|
||||||
|
|
||||||
|
for (Move move : legalMoves){
|
||||||
|
if (probs[move.position()] > bestVal){
|
||||||
|
bestMove = move.position();
|
||||||
|
bestVal = probs[bestMove];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestMove;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void trainFromHistory(StateAction step, double reward){
|
||||||
|
double[] output = new double[64];
|
||||||
|
output[step.action] = reward;
|
||||||
|
|
||||||
|
DataSet ds = new DataSet(
|
||||||
|
Nd4j.create(new int[][] { step.state }),
|
||||||
|
Nd4j.create(new double[][] { output })
|
||||||
|
);
|
||||||
|
|
||||||
|
model.fit(ds);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package org.toop.framework.machinelearning;
|
||||||
|
|
||||||
|
public class StateAction {
|
||||||
|
int[] state;
|
||||||
|
int action;
|
||||||
|
public StateAction(int[] state, int action) {
|
||||||
|
this.state = state;
|
||||||
|
this.action = action;
|
||||||
|
}
|
||||||
|
}
|
||||||
10
game/pom.xml
10
game/pom.xml
@@ -99,6 +99,16 @@
|
|||||||
<artifactId>error_prone_annotations</artifactId>
|
<artifactId>error_prone_annotations</artifactId>
|
||||||
<version>2.42.0</version>
|
<version>2.42.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
<artifactId>deeplearning4j-core</artifactId>
|
||||||
|
<version>1.0.0-M2.1</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-native-platform</artifactId>
|
||||||
|
<version>1.0.0-M2.1</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
private int movesTaken;
|
private int movesTaken;
|
||||||
private Set<Point> filledCells = new HashSet<>();
|
private Set<Point> filledCells = new HashSet<>();
|
||||||
private Move[] mostRecentlyFlippedPieces;
|
private Move[] mostRecentlyFlippedPieces;
|
||||||
|
private char[][] cachedBoard;
|
||||||
|
|
||||||
public record Score(int player1Score, int player2Score) {}
|
public record Score(int player1Score, int player2Score) {}
|
||||||
|
|
||||||
@@ -37,6 +38,7 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
this.setBoard(new Move(35, 'B'));
|
this.setBoard(new Move(35, 'B'));
|
||||||
this.setBoard(new Move(36, 'W'));
|
this.setBoard(new Move(36, 'W'));
|
||||||
updateFilledCellsSet();
|
updateFilledCellsSet();
|
||||||
|
cachedBoard = makeBoardAGrid();
|
||||||
}
|
}
|
||||||
private void updateFilledCellsSet() {
|
private void updateFilledCellsSet() {
|
||||||
for (int i = 0; i < 64; i++) {
|
for (int i = 0; i < 64; i++) {
|
||||||
@@ -49,11 +51,13 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
@Override
|
@Override
|
||||||
public Move[] getLegalMoves() {
|
public Move[] getLegalMoves() {
|
||||||
final ArrayList<Move> legalMoves = new ArrayList<>();
|
final ArrayList<Move> legalMoves = new ArrayList<>();
|
||||||
char[][] boardGrid = makeBoardAGrid();
|
char[][] boardGrid = cachedBoard;
|
||||||
char currentPlayer = (this.getCurrentTurn()==0) ? 'B' : 'W';
|
char currentPlayer = (this.getCurrentTurn()==0) ? 'B' : 'W';
|
||||||
Set<Point> adjCell = getAdjacentCells(boardGrid);
|
char opponent = (currentPlayer=='W') ? 'B' : 'W';
|
||||||
|
|
||||||
|
Set<Point> adjCell = getAdjacentCells(boardGrid, opponent);
|
||||||
for (Point point : adjCell){
|
for (Point point : adjCell){
|
||||||
Move[] moves = getFlipsForPotentialMove(point,currentPlayer);
|
Move[] moves = getFlipsForPotentialMove(point, currentPlayer, opponent, boardGrid);
|
||||||
int score = moves.length;
|
int score = moves.length;
|
||||||
if (score > 0){
|
if (score > 0){
|
||||||
legalMoves.add(new Move(point.x + point.y * this.getRowSize(), currentPlayer));
|
legalMoves.add(new Move(point.x + point.y * this.getRowSize(), currentPlayer));
|
||||||
@@ -62,18 +66,20 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
return legalMoves.toArray(new Move[0]);
|
return legalMoves.toArray(new Move[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<Point> getAdjacentCells(char[][] boardGrid) {
|
private Set<Point> getAdjacentCells(char[][] boardGrid, char opponent) {
|
||||||
Set<Point> possibleCells = new HashSet<>();
|
Set<Point> possibleCells = new HashSet<>();
|
||||||
for (Point point : filledCells) { //for every filled cell
|
for (Point point : filledCells) { //for every filled cell
|
||||||
for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++){ //check adjacent cells
|
if (boardGrid[point.x][point.y] == opponent) {
|
||||||
for (int deltaRow = -1; deltaRow <= 1; deltaRow++){ //orthogonally and diagonally
|
for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //check adjacent cells
|
||||||
int newX = point.x + deltaColumn, newY = point.y + deltaRow;
|
for (int deltaRow = -1; deltaRow <= 1; deltaRow++) { //orthogonally and diagonally
|
||||||
if (deltaColumn == 0 && deltaRow == 0 //continue if out of bounds
|
int newX = point.x + deltaColumn, newY = point.y + deltaRow;
|
||||||
|| !isOnBoard(newX, newY)) {
|
if (deltaColumn == 0 && deltaRow == 0 //continue if out of bounds
|
||||||
continue;
|
|| !isOnBoard(newX, newY)) {
|
||||||
}
|
continue;
|
||||||
if (boardGrid[newY][newX] == EMPTY) { //check if the cell is empty
|
}
|
||||||
possibleCells.add(new Point(newX, newY)); //and then add it to the set of possible moves
|
if (boardGrid[newY][newX] == EMPTY) { //check if the cell is empty
|
||||||
|
possibleCells.add(new Point(newX, newY)); //and then add it to the set of possible moves
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,14 +87,14 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
return possibleCells;
|
return possibleCells;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Move[] getFlipsForPotentialMove(Point point, char currentPlayer) {
|
public Move[] getFlipsForPotentialMove(Point point, char currentPlayer, char opponent, char[][] boardGrid) {
|
||||||
final ArrayList<Move> movesToFlip = new ArrayList<>();
|
final ArrayList<Move> movesToFlip = new ArrayList<>();
|
||||||
for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //for all directions
|
for (int deltaColumn = -1; deltaColumn <= 1; deltaColumn++) { //for all directions
|
||||||
for (int deltaRow = -1; deltaRow <= 1; deltaRow++) {
|
for (int deltaRow = -1; deltaRow <= 1; deltaRow++) {
|
||||||
if (deltaColumn == 0 && deltaRow == 0){
|
if (deltaColumn == 0 && deltaRow == 0){
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Move[] moves = getFlipsInDirection(point,makeBoardAGrid(),currentPlayer,deltaColumn,deltaRow);
|
Move[] moves = getFlipsInDirection(point, boardGrid, currentPlayer, opponent, deltaColumn, deltaRow);
|
||||||
if (moves != null) { //getFlipsInDirection
|
if (moves != null) { //getFlipsInDirection
|
||||||
movesToFlip.addAll(Arrays.asList(moves));
|
movesToFlip.addAll(Arrays.asList(moves));
|
||||||
}
|
}
|
||||||
@@ -97,8 +103,14 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
return movesToFlip.toArray(new Move[0]);
|
return movesToFlip.toArray(new Move[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Move[] getFlipsInDirection(Point point, char[][] boardGrid, char currentPlayer, int dirX, int dirY) {
|
public Move[] getFlipsForPotentialMove(Move move) {
|
||||||
char opponent = getOpponent(currentPlayer);
|
char curr = getCurrentPlayer();
|
||||||
|
char opp = getOpponent(curr);
|
||||||
|
Point point = new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize());
|
||||||
|
return getFlipsForPotentialMove(point, curr, opp, cachedBoard);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Move[] getFlipsInDirection(Point point, char[][] boardGrid, char currentPlayer, char opponent, int dirX, int dirY) {
|
||||||
final ArrayList<Move> movesToFlip = new ArrayList<>();
|
final ArrayList<Move> movesToFlip = new ArrayList<>();
|
||||||
int x = point.x + dirX;
|
int x = point.x + dirX;
|
||||||
int y = point.y + dirY;
|
int y = point.y + dirY;
|
||||||
@@ -123,7 +135,7 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
return x >= 0 && x < this.getColumnSize() && y >= 0 && y < this.getRowSize();
|
return x >= 0 && x < this.getColumnSize() && y >= 0 && y < this.getRowSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
private char[][] makeBoardAGrid() {
|
public char[][] makeBoardAGrid() {
|
||||||
char[][] boardGrid = new char[this.getRowSize()][this.getColumnSize()];
|
char[][] boardGrid = new char[this.getRowSize()][this.getColumnSize()];
|
||||||
for (int i = 0; i < 64; i++) {
|
for (int i = 0; i < 64; i++) {
|
||||||
boardGrid[i / this.getRowSize()][i % this.getColumnSize()] = this.getBoard()[i]; //boardGrid[y -> row] [x -> column]
|
boardGrid[i / this.getRowSize()][i % this.getColumnSize()] = this.getBoard()[i]; //boardGrid[y -> row] [x -> column]
|
||||||
@@ -133,6 +145,9 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public GameState play(Move move) {
|
public GameState play(Move move) {
|
||||||
|
if (cachedBoard == null) {
|
||||||
|
cachedBoard = makeBoardAGrid();
|
||||||
|
}
|
||||||
Move[] legalMoves = getLegalMoves();
|
Move[] legalMoves = getLegalMoves();
|
||||||
boolean moveIsLegal = false;
|
boolean moveIsLegal = false;
|
||||||
for (Move legalMove : legalMoves) { //check if the move is legal
|
for (Move legalMove : legalMoves) { //check if the move is legal
|
||||||
@@ -145,13 +160,14 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
Move[] moves = sortMovesFromCenter(getFlipsForPotentialMove(new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()), move.value()),move);
|
Move[] moves = sortMovesFromCenter(getFlipsForPotentialMove(new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize()), move.value(),move.value() == 'B'? 'W': 'B',makeBoardAGrid()),move);
|
||||||
mostRecentlyFlippedPieces = moves;
|
mostRecentlyFlippedPieces = moves;
|
||||||
this.setBoard(move); //place the move on the board
|
this.setBoard(move); //place the move on the board
|
||||||
for (Move m : moves) {
|
for (Move m : moves) {
|
||||||
this.setBoard(m); //flip the correct pieces on the board
|
this.setBoard(m); //flip the correct pieces on the board
|
||||||
}
|
}
|
||||||
filledCells.add(new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize()));
|
filledCells.add(new Point(move.position() % this.getRowSize(), move.position() / this.getColumnSize()));
|
||||||
|
cachedBoard = makeBoardAGrid();
|
||||||
nextTurn();
|
nextTurn();
|
||||||
if (getLegalMoves().length == 0) { //skip the players turn when there are no legal moves
|
if (getLegalMoves().length == 0) { //skip the players turn when there are no legal moves
|
||||||
skipMyTurn();
|
skipMyTurn();
|
||||||
@@ -172,7 +188,7 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void skipMyTurn(){
|
private void skipMyTurn(){
|
||||||
IO.println("TURN " + getCurrentPlayer() + " SKIPPED");
|
//IO.println("TURN " + getCurrentPlayer() + " SKIPPED");
|
||||||
//TODO: notify user that a turn has been skipped
|
//TODO: notify user that a turn has been skipped
|
||||||
nextTurn();
|
nextTurn();
|
||||||
}
|
}
|
||||||
@@ -207,6 +223,32 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
}
|
}
|
||||||
return new Score(player1Score, player2Score);
|
return new Score(player1Score, player2Score);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isGameOver(){
|
||||||
|
Move[] legalMovesW = getLegalMoves();
|
||||||
|
nextTurn();
|
||||||
|
Move[] legalMovesB = getLegalMoves();
|
||||||
|
nextTurn();
|
||||||
|
if (legalMovesW.length + legalMovesB.length == 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getWinner(){
|
||||||
|
if (!isGameOver()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
Score score = getScore();
|
||||||
|
if (score.player1Score() > score.player2Score()) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
else if (score.player1Score() < score.player2Score()) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
private Move[] sortMovesFromCenter(Move[] moves, Move center) { //sorts the pieces to be flipped for animation purposes
|
private Move[] sortMovesFromCenter(Move[] moves, Move center) { //sorts the pieces to be flipped for animation purposes
|
||||||
int centerX = center.position()%this.getColumnSize();
|
int centerX = center.position()%this.getColumnSize();
|
||||||
int centerY = center.position()/this.getRowSize();
|
int centerY = center.position()/this.getRowSize();
|
||||||
@@ -226,4 +268,34 @@ public final class Reversi extends TurnBasedGame {
|
|||||||
public Move[] getMostRecentlyFlippedPieces() {
|
public Move[] getMostRecentlyFlippedPieces() {
|
||||||
return mostRecentlyFlippedPieces;
|
return mostRecentlyFlippedPieces;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int[] getBoardInt(){
|
||||||
|
char[] input = getBoard();
|
||||||
|
int[] result = new int[input.length];
|
||||||
|
for (int i = 0; i < input.length; i++) {
|
||||||
|
switch (input[i]) {
|
||||||
|
case 'W':
|
||||||
|
result[i] = -1;
|
||||||
|
break;
|
||||||
|
case 'B':
|
||||||
|
result[i] = 1;
|
||||||
|
break;
|
||||||
|
case ' ':
|
||||||
|
default:
|
||||||
|
result[i] = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Point moveToPoint(Move move){
|
||||||
|
return new Point(move.position()%this.getColumnSize(),move.position()/this.getRowSize());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void printBoard(){
|
||||||
|
for (int row = 0; row < this.getRowSize(); row++) {
|
||||||
|
IO.println(Arrays.toString(cachedBoard[row]));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -4,6 +4,7 @@ import org.toop.game.AI;
|
|||||||
import org.toop.game.records.Move;
|
import org.toop.game.records.Move;
|
||||||
|
|
||||||
public final class ReversiAI extends AI<Reversi> {
|
public final class ReversiAI extends AI<Reversi> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Move findBestMove(Reversi game, int depth) {
|
public Move findBestMove(Reversi game, int depth) {
|
||||||
Move[] moves = game.getLegalMoves();
|
Move[] moves = game.getLegalMoves();
|
||||||
|
|||||||
52
game/src/main/java/org/toop/game/reversi/ReversiAIML.java
Normal file
52
game/src/main/java/org/toop/game/reversi/ReversiAIML.java
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package org.toop.game.reversi;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.toop.game.AI;
|
||||||
|
import org.toop.game.records.Move;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
public class ReversiAIML extends AI<Reversi>{
|
||||||
|
|
||||||
|
MultiLayerNetwork model;
|
||||||
|
|
||||||
|
public ReversiAIML() {
|
||||||
|
InputStream is = getClass().getResourceAsStream("/reversi-model.zip");
|
||||||
|
try {
|
||||||
|
assert is != null;
|
||||||
|
model = ModelSerializer.restoreMultiLayerNetwork(is);
|
||||||
|
} catch (IOException e) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Move findBestMove(Reversi reversi, int depth){
|
||||||
|
int[] input = reversi.getBoardInt();
|
||||||
|
|
||||||
|
INDArray boardInput = Nd4j.create(new int[][] { input });
|
||||||
|
INDArray prediction = model.output(boardInput);
|
||||||
|
|
||||||
|
int move = pickLegalMove(prediction, reversi);
|
||||||
|
return new Move(move, reversi.getCurrentPlayer());
|
||||||
|
}
|
||||||
|
|
||||||
|
private int pickLegalMove(INDArray prediction, Reversi reversi){
|
||||||
|
double[] probs = prediction.toDoubleVector();
|
||||||
|
Move[] legalMoves = reversi.getLegalMoves();
|
||||||
|
|
||||||
|
if (legalMoves.length == 0) return -1;
|
||||||
|
|
||||||
|
int bestMove = legalMoves[0].position();
|
||||||
|
double bestVal = probs[bestMove];
|
||||||
|
|
||||||
|
for (Move move : legalMoves){
|
||||||
|
if (probs[move.position()] > bestVal){
|
||||||
|
bestMove = move.position();
|
||||||
|
bestVal = probs[bestMove];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestMove;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package org.toop.game.reversi;
|
||||||
|
|
||||||
|
import org.toop.game.AI;
|
||||||
|
import org.toop.game.records.Move;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
public class ReversiAISimple extends AI<Reversi> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Move findBestMove(Reversi game, int depth) {
|
||||||
|
//IO.println("****START FIND BEST MOVE****");
|
||||||
|
|
||||||
|
Move[] moves = game.getLegalMoves();
|
||||||
|
|
||||||
|
|
||||||
|
//game.printBoard();
|
||||||
|
//IO.println("Legal moves: " + Arrays.toString(moves));
|
||||||
|
|
||||||
|
Move bestMove;
|
||||||
|
Move bestMoveScore = moves[0];
|
||||||
|
Move bestMoveOptions = moves[0];
|
||||||
|
int bestScore = -1;
|
||||||
|
int bestOptions = -1;
|
||||||
|
for (Move move : moves){
|
||||||
|
int numOpt = getNumberOfOptions(game, move);
|
||||||
|
if (numOpt > bestOptions) {
|
||||||
|
bestOptions = numOpt;
|
||||||
|
bestMoveOptions = move;
|
||||||
|
}
|
||||||
|
int numSco = getScore(game, move);
|
||||||
|
if (numSco > bestScore) {
|
||||||
|
bestScore = numSco;
|
||||||
|
bestMoveScore = move;
|
||||||
|
}
|
||||||
|
|
||||||
|
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
|
||||||
|
}
|
||||||
|
if (bestScore > bestOptions) {
|
||||||
|
bestMove = bestMoveScore;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
bestMove = bestMoveOptions;
|
||||||
|
}
|
||||||
|
return bestMove;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getNumberOfOptions(Reversi game, Move move){
|
||||||
|
Reversi copy = new Reversi(game);
|
||||||
|
copy.play(move);
|
||||||
|
return copy.getLegalMoves().length;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getScore(Reversi game, Move move){
|
||||||
|
return game.getFlipsForPotentialMove(move).length;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,17 +8,24 @@ import org.toop.game.enumerators.GameState;
|
|||||||
import org.toop.game.records.Move;
|
import org.toop.game.records.Move;
|
||||||
import org.toop.game.reversi.Reversi;
|
import org.toop.game.reversi.Reversi;
|
||||||
import org.toop.game.reversi.ReversiAI;
|
import org.toop.game.reversi.ReversiAI;
|
||||||
|
import org.toop.game.reversi.ReversiAIML;
|
||||||
|
import org.toop.game.reversi.ReversiAISimple;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
class ReversiTest {
|
class ReversiTest {
|
||||||
private Reversi game;
|
private Reversi game;
|
||||||
private ReversiAI ai;
|
private ReversiAI ai;
|
||||||
|
private ReversiAIML aiml;
|
||||||
|
private ReversiAISimple aiSimple;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
void setup() {
|
void setup() {
|
||||||
game = new Reversi();
|
game = new Reversi();
|
||||||
ai = new ReversiAI();
|
ai = new ReversiAI();
|
||||||
|
aiml = new ReversiAIML();
|
||||||
|
aiSimple = new ReversiAISimple();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -190,4 +197,35 @@ class ReversiTest {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testAIvsAIML(){
|
||||||
|
IO.println("Testing AI simple ...");
|
||||||
|
int totalGames = 5000;
|
||||||
|
int p1wins = 0;
|
||||||
|
int p2wins = 0;
|
||||||
|
int draws = 0;
|
||||||
|
for (int i = 0; i < totalGames; i++) {
|
||||||
|
game = new Reversi();
|
||||||
|
while (!game.isGameOver()) {
|
||||||
|
char curr = game.getCurrentPlayer();
|
||||||
|
if (curr == 'W') {
|
||||||
|
game.play(ai.findBestMove(game,5));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
game.play(ai.findBestMove(game,5));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int winner = game.getWinner();
|
||||||
|
if (winner == 1) {
|
||||||
|
p1wins++;
|
||||||
|
}else if (winner == 2) {
|
||||||
|
p2wins++;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
draws++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user