mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 02:44:50 +00:00
added some useful testing methods.
made training slightly better.
This commit is contained in:
@@ -38,8 +38,8 @@ public final class Main {
|
||||
}).start();
|
||||
|
||||
|
||||
//NeuralNetwork nn = new NeuralNetwork();
|
||||
//nn.init();
|
||||
NeuralNetwork nn = new NeuralNetwork();
|
||||
nn.init();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
|
||||
import org.toop.game.records.Move;
|
||||
import org.toop.game.reversi.Reversi;
|
||||
import org.toop.game.reversi.ReversiAI;
|
||||
import org.toop.game.reversi.ReversiAIML;
|
||||
import org.toop.game.reversi.ReversiAISimple;
|
||||
|
||||
import java.io.File;
|
||||
@@ -26,14 +27,17 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static java.lang.Math.abs;
|
||||
import static java.lang.Math.random;
|
||||
|
||||
public class NeuralNetwork {
|
||||
|
||||
private MultiLayerConfiguration conf;
|
||||
private MultiLayerNetwork model;
|
||||
private AI<Reversi> opponentAI;
|
||||
private ReversiAI reversiAI = new ReversiAI();
|
||||
private AI<Reversi> opponentRand = new ReversiAI();
|
||||
private AI<Reversi> opponentSimple = new ReversiAISimple();
|
||||
private AI<Reversi> opponentAIML = new ReversiAIML();
|
||||
|
||||
|
||||
public NeuralNetwork() {}
|
||||
@@ -61,7 +65,7 @@ public class NeuralNetwork {
|
||||
model.init();
|
||||
IO.println(model.summary());
|
||||
|
||||
model.setLearningRate(0.001);
|
||||
model.setLearningRate(0.0003);
|
||||
trainingLoop();
|
||||
saveModel();
|
||||
}
|
||||
@@ -85,13 +89,15 @@ public class NeuralNetwork {
|
||||
}
|
||||
|
||||
public void trainingLoop(){
|
||||
int totalGames = 5000;
|
||||
double epsilon = 0.1;
|
||||
int totalGames = 50000;
|
||||
double epsilon = 0.05;
|
||||
|
||||
long start = System.nanoTime();
|
||||
|
||||
for (int game = 0; game<totalGames; game++){
|
||||
char modelPlayer = random()<0.5?'B':'W';
|
||||
Reversi reversi = new Reversi();
|
||||
opponentAI = getOpponentAI();
|
||||
List<StateAction> gameHistory = new ArrayList<>();
|
||||
GameState state = GameState.NORMAL;
|
||||
|
||||
@@ -100,7 +106,7 @@ public class NeuralNetwork {
|
||||
while (state != GameState.DRAW && state != GameState.WIN){
|
||||
char curr = reversi.getCurrentPlayer();
|
||||
Move move;
|
||||
if (curr == 'B') {
|
||||
if (curr == modelPlayer) {
|
||||
int[] input = reversi.getBoardInt();
|
||||
if (Math.random() < epsilon) {
|
||||
Move[] moves = reversi.getLegalMoves();
|
||||
@@ -114,7 +120,7 @@ public class NeuralNetwork {
|
||||
move = new Move(location, reversi.getCurrentPlayer());
|
||||
}
|
||||
}else{
|
||||
move = reversiAI.findBestMove(reversi,5);
|
||||
move = opponentAI.findBestMove(reversi,5);
|
||||
}
|
||||
state = reversi.play(move);
|
||||
}
|
||||
@@ -130,6 +136,10 @@ public class NeuralNetwork {
|
||||
reward = 0;
|
||||
}
|
||||
|
||||
if (modelPlayer == 'W'){
|
||||
reward = -reward;
|
||||
}
|
||||
|
||||
|
||||
for (StateAction step : gameHistory){
|
||||
trainFromHistory(step, reward);
|
||||
@@ -167,6 +177,15 @@ public class NeuralNetwork {
|
||||
return bestMove;
|
||||
}
|
||||
|
||||
private AI<Reversi> getOpponentAI(){
|
||||
return switch ((int) (Math.random() * 4)) {
|
||||
case 0 -> opponentRand;
|
||||
case 1 -> opponentSimple;
|
||||
case 2 -> opponentAIML;
|
||||
default -> opponentRand;
|
||||
};
|
||||
}
|
||||
|
||||
private void trainFromHistory(StateAction step, double reward){
|
||||
double[] output = new double[64];
|
||||
output[step.action] = reward;
|
||||
|
||||
@@ -10,6 +10,8 @@ import org.toop.game.records.Move;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
import static java.lang.Math.random;
|
||||
|
||||
public class ReversiAIML extends AI<Reversi>{
|
||||
|
||||
MultiLayerNetwork model;
|
||||
@@ -28,23 +30,27 @@ public class ReversiAIML extends AI<Reversi>{
|
||||
INDArray boardInput = Nd4j.create(new int[][] { input });
|
||||
INDArray prediction = model.output(boardInput);
|
||||
|
||||
int move = pickLegalMove(prediction, reversi);
|
||||
int move = pickLegalMove(prediction,reversi);
|
||||
return new Move(move, reversi.getCurrentPlayer());
|
||||
}
|
||||
|
||||
private int pickLegalMove(INDArray prediction, Reversi reversi){
|
||||
double[] probs = prediction.toDoubleVector();
|
||||
private int pickLegalMove(INDArray prediction, Reversi reversi) {
|
||||
double[] logits = prediction.toDoubleVector();
|
||||
Move[] legalMoves = reversi.getLegalMoves();
|
||||
|
||||
if (legalMoves.length == 0) return -1;
|
||||
|
||||
int bestMove = legalMoves[0].position();
|
||||
double bestVal = probs[bestMove];
|
||||
double bestVal = logits[bestMove];
|
||||
|
||||
for (Move move : legalMoves){
|
||||
if (probs[move.position()] > bestVal){
|
||||
bestMove = move.position();
|
||||
bestVal = probs[bestMove];
|
||||
if (random() < 0.01){
|
||||
return legalMoves[(int)(random()*legalMoves.length-.5)].position();
|
||||
}
|
||||
for (Move move : legalMoves) {
|
||||
int pos = move.position();
|
||||
if (logits[pos] > bestVal) {
|
||||
bestMove = pos;
|
||||
bestVal = logits[pos];
|
||||
}
|
||||
}
|
||||
return bestMove;
|
||||
|
||||
@@ -33,6 +33,12 @@ public class ReversiAISimple extends AI<Reversi> {
|
||||
bestScore = numSco;
|
||||
bestMoveScore = move;
|
||||
}
|
||||
if (numSco == bestScore || numOpt == bestOptions) {
|
||||
if (Math.random() < 0.5) {
|
||||
bestMoveOptions = move;
|
||||
bestMoveScore = move;
|
||||
}
|
||||
}
|
||||
|
||||
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import java.util.*;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.toop.game.AI;
|
||||
import org.toop.game.enumerators.GameState;
|
||||
import org.toop.game.records.Move;
|
||||
import org.toop.game.reversi.Reversi;
|
||||
@@ -18,6 +19,8 @@ class ReversiTest {
|
||||
private ReversiAI ai;
|
||||
private ReversiAIML aiml;
|
||||
private ReversiAISimple aiSimple;
|
||||
private AI<Reversi> player1;
|
||||
private AI<Reversi> player2;
|
||||
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
@@ -32,20 +35,20 @@ class ReversiTest {
|
||||
@Test
|
||||
void testCorrectStartPiecesPlaced() {
|
||||
assertNotNull(game);
|
||||
assertEquals('W',game.getBoard()[27]);
|
||||
assertEquals('B',game.getBoard()[28]);
|
||||
assertEquals('B',game.getBoard()[35]);
|
||||
assertEquals('W',game.getBoard()[36]);
|
||||
assertEquals('W', game.getBoard()[27]);
|
||||
assertEquals('B', game.getBoard()[28]);
|
||||
assertEquals('B', game.getBoard()[35]);
|
||||
assertEquals('W', game.getBoard()[36]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetLegalMovesAtStart() {
|
||||
Move[] moves = game.getLegalMoves();
|
||||
List<Move> expectedMoves = List.of(
|
||||
new Move(19,'B'),
|
||||
new Move(26,'B'),
|
||||
new Move(37,'B'),
|
||||
new Move(44,'B')
|
||||
new Move(19, 'B'),
|
||||
new Move(26, 'B'),
|
||||
new Move(37, 'B'),
|
||||
new Move(44, 'B')
|
||||
);
|
||||
assertNotNull(moves);
|
||||
assertTrue(moves.length > 0);
|
||||
@@ -83,12 +86,12 @@ class ReversiTest {
|
||||
|
||||
@Test
|
||||
void testCountScoreCorrectlyAtStart() {
|
||||
long start = System.nanoTime();
|
||||
long start = System.nanoTime();
|
||||
Reversi.Score score = game.getScore();
|
||||
assertEquals(2, score.player1Score()); // Black
|
||||
assertEquals(2, score.player2Score()); // White
|
||||
long end = System.nanoTime();
|
||||
IO.println((end-start));
|
||||
long end = System.nanoTime();
|
||||
IO.println((end - start));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -97,7 +100,7 @@ class ReversiTest {
|
||||
game.play(new Move(20, 'W'));
|
||||
Move[] moves = game.getLegalMoves();
|
||||
List<Move> expectedMoves = List.of(
|
||||
new Move(13,'B'),
|
||||
new Move(13, 'B'),
|
||||
new Move(21, 'B'),
|
||||
new Move(29, 'B'),
|
||||
new Move(37, 'B'),
|
||||
@@ -110,11 +113,11 @@ class ReversiTest {
|
||||
|
||||
@Test
|
||||
void testCountScoreCorrectlyAtEnd() {
|
||||
for (int i = 0; i < 1; i++){
|
||||
game = new Reversi();
|
||||
for (int i = 0; i < 1; i++) {
|
||||
game = new Reversi();
|
||||
Move[] legalMoves = game.getLegalMoves();
|
||||
while(legalMoves.length > 0) {
|
||||
game.play(legalMoves[(int)(Math.random()*legalMoves.length)]);
|
||||
while (legalMoves.length > 0) {
|
||||
game.play(legalMoves[(int) (Math.random() * legalMoves.length)]);
|
||||
legalMoves = game.getLegalMoves();
|
||||
}
|
||||
Reversi.Score score = game.getScore();
|
||||
@@ -152,8 +155,8 @@ class ReversiTest {
|
||||
game.play(new Move(60, 'W'));
|
||||
game.play(new Move(59, 'B'));
|
||||
assertEquals('B', game.getCurrentPlayer());
|
||||
game.play(ai.findBestMove(game,5));
|
||||
game.play(ai.findBestMove(game,5));
|
||||
game.play(ai.findBestMove(game, 5));
|
||||
game.play(ai.findBestMove(game, 5));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -186,9 +189,9 @@ class ReversiTest {
|
||||
|
||||
@Test
|
||||
void testAISelectsLegalMove() {
|
||||
Move move = ai.findBestMove(game,4);
|
||||
Move move = ai.findBestMove(game, 4);
|
||||
assertNotNull(move);
|
||||
assertTrue(containsMove(game.getLegalMoves(),move), "AI should always choose a legal move");
|
||||
assertTrue(containsMove(game.getLegalMoves(), move), "AI should always choose a legal move");
|
||||
}
|
||||
|
||||
private boolean containsMove(Move[] moves, Move move) {
|
||||
@@ -199,33 +202,60 @@ class ReversiTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAIvsAIML(){
|
||||
IO.println("Testing AI simple ...");
|
||||
int totalGames = 5000;
|
||||
void testAis() {
|
||||
player1 = aiml;
|
||||
player2 = ai;
|
||||
testAIvsAIML();
|
||||
player2 = aiSimple;
|
||||
testAIvsAIML();
|
||||
player1 = ai;
|
||||
testAIvsAIML();
|
||||
player2 = aiml;
|
||||
testAIvsAIML();
|
||||
player1 = aiml;
|
||||
testAIvsAIML();
|
||||
player1 = aiSimple;
|
||||
testAIvsAIML();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAIvsAIML() {
|
||||
if(player1 == null || player2 == null) {
|
||||
player1 = aiml;
|
||||
player2 = ai;
|
||||
}
|
||||
int totalGames = 2000;
|
||||
IO.println("Testing... " + player1.getClass().getSimpleName() + " vs " + player2.getClass().getSimpleName() + " for " + totalGames + " games");
|
||||
int p1wins = 0;
|
||||
int p2wins = 0;
|
||||
int draws = 0;
|
||||
List<Integer> moves = new ArrayList<>();
|
||||
for (int i = 0; i < totalGames; i++) {
|
||||
game = new Reversi();
|
||||
game = new Reversi();
|
||||
while (!game.isGameOver()) {
|
||||
char curr = game.getCurrentPlayer();
|
||||
if (curr == 'W') {
|
||||
game.play(ai.findBestMove(game,5));
|
||||
}
|
||||
else {
|
||||
game.play(ai.findBestMove(game,5));
|
||||
Move move;
|
||||
if (curr == 'B') {
|
||||
move = player1.findBestMove(game, 5);
|
||||
} else {
|
||||
move = player2.findBestMove(game, 5);
|
||||
}
|
||||
if (i%500 == 0) moves.add(move.position());
|
||||
game.play(move);
|
||||
}
|
||||
if (i%500 == 0) {
|
||||
IO.println(moves);
|
||||
moves.clear();
|
||||
}
|
||||
int winner = game.getWinner();
|
||||
if (winner == 1) {
|
||||
p1wins++;
|
||||
}else if (winner == 2) {
|
||||
} else if (winner == 2) {
|
||||
p2wins++;
|
||||
}
|
||||
else{
|
||||
} else {
|
||||
draws++;
|
||||
}
|
||||
}
|
||||
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws);
|
||||
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double) p1wins / totalGames + "\np2wins: " + p2wins + " draws: " + draws);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user