Added some useful testing methods.

Made training slightly better.
This commit is contained in:
Ticho Hidding
2025-12-08 11:36:31 +01:00
parent 7e913ff50f
commit f6d90ed439
5 changed files with 110 additions and 49 deletions

View File

@@ -38,8 +38,8 @@ public final class Main {
}).start(); }).start();
//NeuralNetwork nn = new NeuralNetwork(); NeuralNetwork nn = new NeuralNetwork();
//nn.init(); nn.init();
} }
} }

View File

@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move; import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi; import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI; import org.toop.game.reversi.ReversiAI;
import org.toop.game.reversi.ReversiAIML;
import org.toop.game.reversi.ReversiAISimple; import org.toop.game.reversi.ReversiAISimple;
import java.io.File; import java.io.File;
@@ -26,14 +27,17 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static java.lang.Math.abs; import static java.lang.Math.abs;
import static java.lang.Math.random;
public class NeuralNetwork { public class NeuralNetwork {
private MultiLayerConfiguration conf; private MultiLayerConfiguration conf;
private MultiLayerNetwork model; private MultiLayerNetwork model;
private AI<Reversi> opponentAI;
private ReversiAI reversiAI = new ReversiAI(); private ReversiAI reversiAI = new ReversiAI();
private AI<Reversi> opponentRand = new ReversiAI(); private AI<Reversi> opponentRand = new ReversiAI();
private AI<Reversi> opponentSimple = new ReversiAISimple(); private AI<Reversi> opponentSimple = new ReversiAISimple();
private AI<Reversi> opponentAIML = new ReversiAIML();
public NeuralNetwork() {} public NeuralNetwork() {}
@@ -61,7 +65,7 @@ public class NeuralNetwork {
model.init(); model.init();
IO.println(model.summary()); IO.println(model.summary());
model.setLearningRate(0.001); model.setLearningRate(0.0003);
trainingLoop(); trainingLoop();
saveModel(); saveModel();
} }
@@ -85,13 +89,15 @@ public class NeuralNetwork {
} }
public void trainingLoop(){ public void trainingLoop(){
int totalGames = 5000; int totalGames = 50000;
double epsilon = 0.1; double epsilon = 0.05;
long start = System.nanoTime(); long start = System.nanoTime();
for (int game = 0; game<totalGames; game++){ for (int game = 0; game<totalGames; game++){
char modelPlayer = random()<0.5?'B':'W';
Reversi reversi = new Reversi(); Reversi reversi = new Reversi();
opponentAI = getOpponentAI();
List<StateAction> gameHistory = new ArrayList<>(); List<StateAction> gameHistory = new ArrayList<>();
GameState state = GameState.NORMAL; GameState state = GameState.NORMAL;
@@ -100,7 +106,7 @@ public class NeuralNetwork {
while (state != GameState.DRAW && state != GameState.WIN){ while (state != GameState.DRAW && state != GameState.WIN){
char curr = reversi.getCurrentPlayer(); char curr = reversi.getCurrentPlayer();
Move move; Move move;
if (curr == 'B') { if (curr == modelPlayer) {
int[] input = reversi.getBoardInt(); int[] input = reversi.getBoardInt();
if (Math.random() < epsilon) { if (Math.random() < epsilon) {
Move[] moves = reversi.getLegalMoves(); Move[] moves = reversi.getLegalMoves();
@@ -114,7 +120,7 @@ public class NeuralNetwork {
move = new Move(location, reversi.getCurrentPlayer()); move = new Move(location, reversi.getCurrentPlayer());
} }
}else{ }else{
move = reversiAI.findBestMove(reversi,5); move = opponentAI.findBestMove(reversi,5);
} }
state = reversi.play(move); state = reversi.play(move);
} }
@@ -130,6 +136,10 @@ public class NeuralNetwork {
reward = 0; reward = 0;
} }
if (modelPlayer == 'W'){
reward = -reward;
}
for (StateAction step : gameHistory){ for (StateAction step : gameHistory){
trainFromHistory(step, reward); trainFromHistory(step, reward);
@@ -167,6 +177,15 @@ public class NeuralNetwork {
return bestMove; return bestMove;
} }
/**
 * Picks the opponent for the next training game at random.
 *
 * <p>An integer roll in {0, 1, 2, 3} is derived from {@link Math#random()}:
 * 1 selects {@code opponentSimple}, 2 selects {@code opponentAIML}, and both
 * 0 and 3 select {@code opponentRand}. {@code opponentRand} is therefore
 * chosen twice as often as either of the other two opponents.
 *
 * @return the opponent AI selected by the roll
 */
private AI<Reversi> getOpponentAI(){
    int roll = (int) (Math.random() * 4);
    if (roll == 1) {
        return opponentSimple;
    }
    if (roll == 2) {
        return opponentAIML;
    }
    // Rolls 0 and 3 both fall through to the same opponent.
    return opponentRand;
}
private void trainFromHistory(StateAction step, double reward){ private void trainFromHistory(StateAction step, double reward){
double[] output = new double[64]; double[] output = new double[64];
output[step.action] = reward; output[step.action] = reward;

View File

@@ -10,6 +10,8 @@ import org.toop.game.records.Move;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import static java.lang.Math.random;
public class ReversiAIML extends AI<Reversi>{ public class ReversiAIML extends AI<Reversi>{
MultiLayerNetwork model; MultiLayerNetwork model;
@@ -28,23 +30,27 @@ public class ReversiAIML extends AI<Reversi>{
INDArray boardInput = Nd4j.create(new int[][] { input }); INDArray boardInput = Nd4j.create(new int[][] { input });
INDArray prediction = model.output(boardInput); INDArray prediction = model.output(boardInput);
int move = pickLegalMove(prediction, reversi); int move = pickLegalMove(prediction,reversi);
return new Move(move, reversi.getCurrentPlayer()); return new Move(move, reversi.getCurrentPlayer());
} }
private int pickLegalMove(INDArray prediction, Reversi reversi){ private int pickLegalMove(INDArray prediction, Reversi reversi) {
double[] probs = prediction.toDoubleVector(); double[] logits = prediction.toDoubleVector();
Move[] legalMoves = reversi.getLegalMoves(); Move[] legalMoves = reversi.getLegalMoves();
if (legalMoves.length == 0) return -1; if (legalMoves.length == 0) return -1;
int bestMove = legalMoves[0].position(); int bestMove = legalMoves[0].position();
double bestVal = probs[bestMove]; double bestVal = logits[bestMove];
for (Move move : legalMoves){ if (random() < 0.01){
if (probs[move.position()] > bestVal){ return legalMoves[(int)(random()*legalMoves.length-.5)].position();
bestMove = move.position(); }
bestVal = probs[bestMove]; for (Move move : legalMoves) {
int pos = move.position();
if (logits[pos] > bestVal) {
bestMove = pos;
bestVal = logits[pos];
} }
} }
return bestMove; return bestMove;

View File

@@ -33,6 +33,12 @@ public class ReversiAISimple extends AI<Reversi> {
bestScore = numSco; bestScore = numSco;
bestMoveScore = move; bestMoveScore = move;
} }
if (numSco == bestScore || numOpt == bestOptions) {
if (Math.random() < 0.5) {
bestMoveOptions = move;
bestMoveScore = move;
}
}
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco); //IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
} }

View File

@@ -4,6 +4,7 @@ import java.util.*;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.toop.game.AI;
import org.toop.game.enumerators.GameState; import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move; import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi; import org.toop.game.reversi.Reversi;
@@ -18,6 +19,8 @@ class ReversiTest {
private ReversiAI ai; private ReversiAI ai;
private ReversiAIML aiml; private ReversiAIML aiml;
private ReversiAISimple aiSimple; private ReversiAISimple aiSimple;
private AI<Reversi> player1;
private AI<Reversi> player2;
@BeforeEach @BeforeEach
void setup() { void setup() {
@@ -32,20 +35,20 @@ class ReversiTest {
@Test @Test
void testCorrectStartPiecesPlaced() { void testCorrectStartPiecesPlaced() {
assertNotNull(game); assertNotNull(game);
assertEquals('W',game.getBoard()[27]); assertEquals('W', game.getBoard()[27]);
assertEquals('B',game.getBoard()[28]); assertEquals('B', game.getBoard()[28]);
assertEquals('B',game.getBoard()[35]); assertEquals('B', game.getBoard()[35]);
assertEquals('W',game.getBoard()[36]); assertEquals('W', game.getBoard()[36]);
} }
@Test @Test
void testGetLegalMovesAtStart() { void testGetLegalMovesAtStart() {
Move[] moves = game.getLegalMoves(); Move[] moves = game.getLegalMoves();
List<Move> expectedMoves = List.of( List<Move> expectedMoves = List.of(
new Move(19,'B'), new Move(19, 'B'),
new Move(26,'B'), new Move(26, 'B'),
new Move(37,'B'), new Move(37, 'B'),
new Move(44,'B') new Move(44, 'B')
); );
assertNotNull(moves); assertNotNull(moves);
assertTrue(moves.length > 0); assertTrue(moves.length > 0);
@@ -83,12 +86,12 @@ class ReversiTest {
@Test @Test
void testCountScoreCorrectlyAtStart() { void testCountScoreCorrectlyAtStart() {
long start = System.nanoTime(); long start = System.nanoTime();
Reversi.Score score = game.getScore(); Reversi.Score score = game.getScore();
assertEquals(2, score.player1Score()); // Black assertEquals(2, score.player1Score()); // Black
assertEquals(2, score.player2Score()); // White assertEquals(2, score.player2Score()); // White
long end = System.nanoTime(); long end = System.nanoTime();
IO.println((end-start)); IO.println((end - start));
} }
@Test @Test
@@ -97,7 +100,7 @@ class ReversiTest {
game.play(new Move(20, 'W')); game.play(new Move(20, 'W'));
Move[] moves = game.getLegalMoves(); Move[] moves = game.getLegalMoves();
List<Move> expectedMoves = List.of( List<Move> expectedMoves = List.of(
new Move(13,'B'), new Move(13, 'B'),
new Move(21, 'B'), new Move(21, 'B'),
new Move(29, 'B'), new Move(29, 'B'),
new Move(37, 'B'), new Move(37, 'B'),
@@ -110,11 +113,11 @@ class ReversiTest {
@Test @Test
void testCountScoreCorrectlyAtEnd() { void testCountScoreCorrectlyAtEnd() {
for (int i = 0; i < 1; i++){ for (int i = 0; i < 1; i++) {
game = new Reversi(); game = new Reversi();
Move[] legalMoves = game.getLegalMoves(); Move[] legalMoves = game.getLegalMoves();
while(legalMoves.length > 0) { while (legalMoves.length > 0) {
game.play(legalMoves[(int)(Math.random()*legalMoves.length)]); game.play(legalMoves[(int) (Math.random() * legalMoves.length)]);
legalMoves = game.getLegalMoves(); legalMoves = game.getLegalMoves();
} }
Reversi.Score score = game.getScore(); Reversi.Score score = game.getScore();
@@ -152,8 +155,8 @@ class ReversiTest {
game.play(new Move(60, 'W')); game.play(new Move(60, 'W'));
game.play(new Move(59, 'B')); game.play(new Move(59, 'B'));
assertEquals('B', game.getCurrentPlayer()); assertEquals('B', game.getCurrentPlayer());
game.play(ai.findBestMove(game,5)); game.play(ai.findBestMove(game, 5));
game.play(ai.findBestMove(game,5)); game.play(ai.findBestMove(game, 5));
} }
@Test @Test
@@ -186,9 +189,9 @@ class ReversiTest {
@Test @Test
void testAISelectsLegalMove() { void testAISelectsLegalMove() {
Move move = ai.findBestMove(game,4); Move move = ai.findBestMove(game, 4);
assertNotNull(move); assertNotNull(move);
assertTrue(containsMove(game.getLegalMoves(),move), "AI should always choose a legal move"); assertTrue(containsMove(game.getLegalMoves(), move), "AI should always choose a legal move");
} }
private boolean containsMove(Move[] moves, Move move) { private boolean containsMove(Move[] moves, Move move) {
@@ -199,33 +202,60 @@ class ReversiTest {
} }
@Test @Test
void testAIvsAIML(){ void testAis() {
IO.println("Testing AI simple ..."); player1 = aiml;
int totalGames = 5000; player2 = ai;
testAIvsAIML();
player2 = aiSimple;
testAIvsAIML();
player1 = ai;
testAIvsAIML();
player2 = aiml;
testAIvsAIML();
player1 = aiml;
testAIvsAIML();
player1 = aiSimple;
testAIvsAIML();
}
@Test
void testAIvsAIML() {
if(player1 == null || player2 == null) {
player1 = aiml;
player2 = ai;
}
int totalGames = 2000;
IO.println("Testing... " + player1.getClass().getSimpleName() + " vs " + player2.getClass().getSimpleName() + " for " + totalGames + " games");
int p1wins = 0; int p1wins = 0;
int p2wins = 0; int p2wins = 0;
int draws = 0; int draws = 0;
List<Integer> moves = new ArrayList<>();
for (int i = 0; i < totalGames; i++) { for (int i = 0; i < totalGames; i++) {
game = new Reversi(); game = new Reversi();
while (!game.isGameOver()) { while (!game.isGameOver()) {
char curr = game.getCurrentPlayer(); char curr = game.getCurrentPlayer();
if (curr == 'W') { Move move;
game.play(ai.findBestMove(game,5)); if (curr == 'B') {
} move = player1.findBestMove(game, 5);
else { } else {
game.play(ai.findBestMove(game,5)); move = player2.findBestMove(game, 5);
} }
if (i%500 == 0) moves.add(move.position());
game.play(move);
}
if (i%500 == 0) {
IO.println(moves);
moves.clear();
} }
int winner = game.getWinner(); int winner = game.getWinner();
if (winner == 1) { if (winner == 1) {
p1wins++; p1wins++;
}else if (winner == 2) { } else if (winner == 2) {
p2wins++; p2wins++;
} } else {
else{
draws++; draws++;
} }
} }
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws); IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double) p1wins / totalGames + "\np2wins: " + p2wins + " draws: " + draws);
} }
} }