added methods for testing the AIs against each other.

made training slightly better: lower learning rate, more games, varied opponents, and random side assignment for the model.
Ticho Hidding
2025-12-08 11:36:31 +01:00
parent 7e913ff50f
commit f6d90ed439
5 changed files with 110 additions and 49 deletions

Main.java

@@ -38,8 +38,8 @@ public final class Main {
}).start();
-//NeuralNetwork nn = new NeuralNetwork();
-//nn.init();
+NeuralNetwork nn = new NeuralNetwork();
+nn.init();
}
}

NeuralNetwork.java

@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI;
+import org.toop.game.reversi.ReversiAIML;
import org.toop.game.reversi.ReversiAISimple;
import java.io.File;
@@ -26,14 +27,17 @@ import java.util.ArrayList;
import java.util.List;
import static java.lang.Math.abs;
import static java.lang.Math.random;
public class NeuralNetwork {
private MultiLayerConfiguration conf;
private MultiLayerNetwork model;
+private AI<Reversi> opponentAI;
-private ReversiAI reversiAI = new ReversiAI();
+private AI<Reversi> opponentRand = new ReversiAI();
+private AI<Reversi> opponentSimple = new ReversiAISimple();
+private AI<Reversi> opponentAIML = new ReversiAIML();
public NeuralNetwork() {}
@@ -61,7 +65,7 @@ public class NeuralNetwork {
model.init();
IO.println(model.summary());
-model.setLearningRate(0.001);
+model.setLearningRate(0.0003);
trainingLoop();
saveModel();
}
@@ -85,13 +89,15 @@ public class NeuralNetwork {
}
public void trainingLoop(){
-int totalGames = 5000;
-double epsilon = 0.1;
+int totalGames = 50000;
+double epsilon = 0.05;
+long start = System.nanoTime();
for (int game = 0; game<totalGames; game++){
+char modelPlayer = random()<0.5?'B':'W';
Reversi reversi = new Reversi();
+opponentAI = getOpponentAI();
List<StateAction> gameHistory = new ArrayList<>();
GameState state = GameState.NORMAL;
@@ -100,7 +106,7 @@ public class NeuralNetwork {
while (state != GameState.DRAW && state != GameState.WIN){
char curr = reversi.getCurrentPlayer();
Move move;
-if (curr == 'B') {
+if (curr == modelPlayer) {
int[] input = reversi.getBoardInt();
if (Math.random() < epsilon) {
Move[] moves = reversi.getLegalMoves();
@@ -114,7 +120,7 @@ public class NeuralNetwork {
move = new Move(location, reversi.getCurrentPlayer());
}
}else{
-move = reversiAI.findBestMove(reversi,5);
+move = opponentAI.findBestMove(reversi,5);
}
state = reversi.play(move);
}
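The loop above plays epsilon-greedy self-play: with probability epsilon the model takes a uniformly random legal move, otherwise the legal move its network scores highest. A minimal sketch of that selection step, assuming a hypothetical predict(int[]) helper standing in for the model's forward pass:

    // Sketch of epsilon-greedy selection; predict() is a hypothetical
    // stand-in for the model's forward pass (one score per board cell).
    int selectMove(Reversi reversi, double epsilon) {
        Move[] legal = reversi.getLegalMoves();
        if (Math.random() < epsilon) {
            // Explore: uniformly random legal move.
            return legal[(int) (Math.random() * legal.length)].position();
        }
        // Exploit: legal move with the highest predicted score.
        double[] scores = predict(reversi.getBoardInt());
        int best = legal[0].position();
        for (Move m : legal) {
            if (scores[m.position()] > scores[best]) best = m.position();
        }
        return best;
    }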
@@ -130,6 +136,10 @@ public class NeuralNetwork {
reward = 0;
}
+if (modelPlayer == 'W'){
+reward = -reward;
+}
for (StateAction step : gameHistory){
trainFromHistory(step, reward);
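Because the model's side is now drawn at random each game, the reward has to be expressed from the model's perspective. Assuming the usual +1 win / -1 loss / 0 draw scheme computed for Black (only the draw case is visible in this hunk), the flip works out as:

    // Reward seen by the model; the win/loss values are an assumption,
    // and the tests below treat winner 1 as Black, winner 2 as White.
    double reward = switch (state) {
        case WIN -> (reversi.getWinner() == 1) ? 1.0 : -1.0;
        case DRAW -> 0.0;
        default -> 0.0;
    };
    if (modelPlayer == 'W') reward = -reward; // model played White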
@@ -167,6 +177,15 @@ public class NeuralNetwork {
return bestMove;
}
+private AI<Reversi> getOpponentAI(){
+return switch ((int) (Math.random() * 4)) {
+case 0 -> opponentRand;
+case 1 -> opponentSimple;
+case 2 -> opponentAIML;
+default -> opponentRand;
+};
+}
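Note that (int) (Math.random() * 4) yields 0 through 3 and case 3 falls through to default, so opponentRand is drawn with probability 1/2 while the other two opponents get 1/4 each. If a uniform mix were intended, a list-based pick is a simple alternative (a sketch, assuming java.util.List is imported):

    // Uniform alternative: each opponent with probability 1/3.
    private final List<AI<Reversi>> opponents =
            List.of(opponentRand, opponentSimple, opponentAIML);

    private AI<Reversi> getOpponentAIUniform() {
        return opponents.get((int) (Math.random() * opponents.size()));
    }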
private void trainFromHistory(StateAction step, double reward){
double[] output = new double[64];
output[step.action] = reward;

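The visible fragment of trainFromHistory builds a 64-entry target with the reward placed on the played cell. A sketch of how the rest of the step could look, assuming StateAction carries the board input as an int[], the Nd4j imports already used elsewhere in the project, and single-example fitting via DL4J's MultiLayerNetwork.fit(features, labels):

    private void trainFromHistory(StateAction step, double reward) {
        double[] output = new double[64];
        output[step.action] = reward; // target only on the played cell

        // step.state as int[] is an assumption; mirrors the
        // Nd4j.create(new int[][] { input }) call in ReversiAIML.
        INDArray features = Nd4j.create(new int[][] { step.state });
        INDArray labels = Nd4j.create(new double[][] { output });
        model.fit(features, labels); // single-example gradient step
    }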
ReversiAIML.java

@@ -10,6 +10,8 @@ import org.toop.game.records.Move;
import java.io.IOException;
import java.io.InputStream;
+import static java.lang.Math.random;
public class ReversiAIML extends AI<Reversi>{
MultiLayerNetwork model;
@@ -28,23 +30,27 @@ public class ReversiAIML extends AI<Reversi>{
INDArray boardInput = Nd4j.create(new int[][] { input });
INDArray prediction = model.output(boardInput);
-int move = pickLegalMove(prediction, reversi);
+int move = pickLegalMove(prediction,reversi);
return new Move(move, reversi.getCurrentPlayer());
}
-private int pickLegalMove(INDArray prediction, Reversi reversi){
-double[] probs = prediction.toDoubleVector();
+private int pickLegalMove(INDArray prediction, Reversi reversi) {
+double[] logits = prediction.toDoubleVector();
Move[] legalMoves = reversi.getLegalMoves();
if (legalMoves.length == 0) return -1;
int bestMove = legalMoves[0].position();
-double bestVal = probs[bestMove];
+double bestVal = logits[bestMove];
-for (Move move : legalMoves){
-if (probs[move.position()] > bestVal){
-bestMove = move.position();
-bestVal = probs[bestMove];
+if (random() < 0.01){
+return legalMoves[(int)(random()*legalMoves.length-.5)].position();
+}
+for (Move move : legalMoves) {
+int pos = move.position();
+if (logits[pos] > bestVal) {
+bestMove = pos;
+bestVal = logits[pos];
}
}
return bestMove;

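The new 1% exploration branch indexes with (int) (random() * legalMoves.length - .5). Truncation toward zero keeps the result inside [0, length - 1], but it skews the draw: the first legal move is picked with probability 1.5/n and the last with only 0.5/n. Dropping the -.5 gives the uniform idiom:

    // Uniform 1% exploration: (int) (random() * n) already lands in [0, n-1].
    if (random() < 0.01) {
        return legalMoves[(int) (random() * legalMoves.length)].position();
    }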
ReversiAISimple.java

@@ -33,6 +33,12 @@ public class ReversiAISimple extends AI<Reversi> {
bestScore = numSco;
bestMoveScore = move;
}
+if (numSco == bestScore || numOpt == bestOptions) {
+if (Math.random() < 0.5) {
+bestMoveOptions = move;
+bestMoveScore = move;
+}
+}
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
}

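The added tie-break flips one coin per tie and updates both trackers whenever either count matches, so later candidates win ties more often than earlier ones (the last of m tied moves survives with probability 1/2, the first with 1/2^m). Reservoir sampling keeps every tied move equally likely; a sketch for the score tracker, with scoreOf() as a hypothetical stand-in for the score computed in this loop:

    // Unbiased tie-breaking via reservoir sampling: the i-th tied
    // candidate replaces the current pick with probability 1/i.
    int ties = 1;
    for (Move move : legalMoves) {
        int numSco = scoreOf(move); // hypothetical score helper
        if (numSco > bestScore) {
            bestScore = numSco;
            bestMoveScore = move;
            ties = 1;
        } else if (numSco == bestScore && Math.random() * ++ties < 1) {
            bestMoveScore = move;
        }
    }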
ReversiTest.java

@@ -4,6 +4,7 @@ import java.util.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import org.toop.game.AI;
import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi;
@@ -18,6 +19,8 @@ class ReversiTest {
private ReversiAI ai;
private ReversiAIML aiml;
private ReversiAISimple aiSimple;
+private AI<Reversi> player1;
+private AI<Reversi> player2;
@BeforeEach
void setup() {
@@ -32,20 +35,20 @@ class ReversiTest {
@Test
void testCorrectStartPiecesPlaced() {
assertNotNull(game);
-assertEquals('W',game.getBoard()[27]);
-assertEquals('B',game.getBoard()[28]);
-assertEquals('B',game.getBoard()[35]);
-assertEquals('W',game.getBoard()[36]);
+assertEquals('W', game.getBoard()[27]);
+assertEquals('B', game.getBoard()[28]);
+assertEquals('B', game.getBoard()[35]);
+assertEquals('W', game.getBoard()[36]);
}
@Test
void testGetLegalMovesAtStart() {
Move[] moves = game.getLegalMoves();
List<Move> expectedMoves = List.of(
-new Move(19,'B'),
-new Move(26,'B'),
-new Move(37,'B'),
-new Move(44,'B')
+new Move(19, 'B'),
+new Move(26, 'B'),
+new Move(37, 'B'),
+new Move(44, 'B')
);
assertNotNull(moves);
assertTrue(moves.length > 0);
@@ -83,12 +86,12 @@ class ReversiTest {
@Test
void testCountScoreCorrectlyAtStart() {
-long start = System.nanoTime();
+long start = System.nanoTime();
Reversi.Score score = game.getScore();
assertEquals(2, score.player1Score()); // Black
assertEquals(2, score.player2Score()); // White
-long end = System.nanoTime();
-IO.println((end-start));
+long end = System.nanoTime();
+IO.println((end - start));
}
@Test
@@ -97,7 +100,7 @@ class ReversiTest {
game.play(new Move(20, 'W'));
Move[] moves = game.getLegalMoves();
List<Move> expectedMoves = List.of(
-new Move(13,'B'),
+new Move(13, 'B'),
new Move(21, 'B'),
new Move(29, 'B'),
new Move(37, 'B'),
@@ -110,11 +113,11 @@ class ReversiTest {
@Test
void testCountScoreCorrectlyAtEnd() {
-for (int i = 0; i < 1; i++){
-game = new Reversi();
+for (int i = 0; i < 1; i++) {
+game = new Reversi();
Move[] legalMoves = game.getLegalMoves();
-while(legalMoves.length > 0) {
-game.play(legalMoves[(int)(Math.random()*legalMoves.length)]);
+while (legalMoves.length > 0) {
+game.play(legalMoves[(int) (Math.random() * legalMoves.length)]);
legalMoves = game.getLegalMoves();
}
Reversi.Score score = game.getScore();
@@ -152,8 +155,8 @@ class ReversiTest {
game.play(new Move(60, 'W'));
game.play(new Move(59, 'B'));
assertEquals('B', game.getCurrentPlayer());
-game.play(ai.findBestMove(game,5));
-game.play(ai.findBestMove(game,5));
+game.play(ai.findBestMove(game, 5));
+game.play(ai.findBestMove(game, 5));
}
@Test
@@ -186,9 +189,9 @@ class ReversiTest {
@Test
void testAISelectsLegalMove() {
-Move move = ai.findBestMove(game,4);
+Move move = ai.findBestMove(game, 4);
assertNotNull(move);
-assertTrue(containsMove(game.getLegalMoves(),move), "AI should always choose a legal move");
+assertTrue(containsMove(game.getLegalMoves(), move), "AI should always choose a legal move");
}
private boolean containsMove(Move[] moves, Move move) {
@@ -199,33 +202,60 @@ class ReversiTest {
}
@Test
-void testAIvsAIML(){
-IO.println("Testing AI simple ...");
-int totalGames = 5000;
+void testAis() {
+player1 = aiml;
+player2 = ai;
+testAIvsAIML();
+player2 = aiSimple;
+testAIvsAIML();
+player1 = ai;
+testAIvsAIML();
+player2 = aiml;
+testAIvsAIML();
+player1 = aiml;
+testAIvsAIML();
+player1 = aiSimple;
+testAIvsAIML();
+}
+@Test
+void testAIvsAIML() {
+if(player1 == null || player2 == null) {
+player1 = aiml;
+player2 = ai;
+}
+int totalGames = 2000;
+IO.println("Testing... " + player1.getClass().getSimpleName() + " vs " + player2.getClass().getSimpleName() + " for " + totalGames + " games");
int p1wins = 0;
int p2wins = 0;
int draws = 0;
+List<Integer> moves = new ArrayList<>();
for (int i = 0; i < totalGames; i++) {
-game = new Reversi();
+game = new Reversi();
while (!game.isGameOver()) {
char curr = game.getCurrentPlayer();
-if (curr == 'W') {
-game.play(ai.findBestMove(game,5));
-}
-else {
-game.play(ai.findBestMove(game,5));
+Move move;
+if (curr == 'B') {
+move = player1.findBestMove(game, 5);
+} else {
+move = player2.findBestMove(game, 5);
}
+if (i%500 == 0) moves.add(move.position());
+game.play(move);
}
+if (i%500 == 0) {
+IO.println(moves);
+moves.clear();
+}
int winner = game.getWinner();
if (winner == 1) {
p1wins++;
-}else if (winner == 2) {
+} else if (winner == 2) {
p2wins++;
-}
-else{
+} else {
draws++;
}
}
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws);
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double) p1wins / totalGames + "\np2wins: " + p2wins + " draws: " + draws);
}
}
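testAis drives the matchups by mutating the player1/player2 fields between calls, while testAIvsAIML stays @Test-annotated so it also runs standalone through its null-check defaults. The same matchup can be expressed as a parameterized helper; a sketch against the Reversi API used above:

    private void playMatch(AI<Reversi> black, AI<Reversi> white, int totalGames) {
        int blackWins = 0, whiteWins = 0, draws = 0;
        for (int i = 0; i < totalGames; i++) {
            Reversi g = new Reversi();
            while (!g.isGameOver()) {
                AI<Reversi> toMove = (g.getCurrentPlayer() == 'B') ? black : white;
                g.play(toMove.findBestMove(g, 5));
            }
            switch (g.getWinner()) { // 1 = Black, 2 = White
                case 1 -> blackWins++;
                case 2 -> whiteWins++;
                default -> draws++;
            }
        }
        IO.println(black.getClass().getSimpleName() + " vs "
                + white.getClass().getSimpleName() + ": "
                + blackWins + "/" + whiteWins + "/" + draws
                + " (black wins / white wins / draws)");
    }

Usage would then read, for example, playMatch(aiml, ai, 2000); for each pairing instead of field reassignment.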