added some useful testing methods.

made training slightly better.
Ticho Hidding
2025-12-08 11:36:31 +01:00
parent 7e913ff50f
commit f6d90ed439
5 changed files with 110 additions and 49 deletions

Main.java (View File)

@@ -38,8 +38,8 @@ public final class Main {
         }).start();
-        //NeuralNetwork nn = new NeuralNetwork();
-        //nn.init();
+        NeuralNetwork nn = new NeuralNetwork();
+        nn.init();
     }
 }

NeuralNetwork.java (View File)

@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
 import org.toop.game.records.Move;
 import org.toop.game.reversi.Reversi;
 import org.toop.game.reversi.ReversiAI;
+import org.toop.game.reversi.ReversiAIML;
 import org.toop.game.reversi.ReversiAISimple;
 import java.io.File;
@@ -26,14 +27,17 @@ import java.util.ArrayList;
 import java.util.List;
 
 import static java.lang.Math.abs;
+import static java.lang.Math.random;
 
 public class NeuralNetwork {
     private MultiLayerConfiguration conf;
     private MultiLayerNetwork model;
+    private AI<Reversi> opponentAI;
     private ReversiAI reversiAI = new ReversiAI();
     private AI<Reversi> opponentRand = new ReversiAI();
     private AI<Reversi> opponentSimple = new ReversiAISimple();
+    private AI<Reversi> opponentAIML = new ReversiAIML();
 
     public NeuralNetwork() {}
@@ -61,7 +65,7 @@ public class NeuralNetwork {
         model.init();
         IO.println(model.summary());
-        model.setLearningRate(0.001);
+        model.setLearningRate(0.0003);
         trainingLoop();
         saveModel();
     }
@@ -85,13 +89,15 @@ public class NeuralNetwork {
     }
 
     public void trainingLoop(){
-        int totalGames = 5000;
-        double epsilon = 0.1;
+        int totalGames = 50000;
+        double epsilon = 0.05;
         long start = System.nanoTime();
         for (int game = 0; game<totalGames; game++){
+            char modelPlayer = random()<0.5?'B':'W';
             Reversi reversi = new Reversi();
+            opponentAI = getOpponentAI();
             List<StateAction> gameHistory = new ArrayList<>();
             GameState state = GameState.NORMAL;
@@ -100,7 +106,7 @@ public class NeuralNetwork {
             while (state != GameState.DRAW && state != GameState.WIN){
                 char curr = reversi.getCurrentPlayer();
                 Move move;
-                if (curr == 'B') {
+                if (curr == modelPlayer) {
                     int[] input = reversi.getBoardInt();
                     if (Math.random() < epsilon) {
                         Move[] moves = reversi.getLegalMoves();
@@ -114,7 +120,7 @@ public class NeuralNetwork {
                         move = new Move(location, reversi.getCurrentPlayer());
                     }
                 }else{
-                    move = reversiAI.findBestMove(reversi,5);
+                    move = opponentAI.findBestMove(reversi,5);
                 }
                 state = reversi.play(move);
             }
@@ -130,6 +136,10 @@ public class NeuralNetwork {
                 reward = 0;
             }
 
+            if (modelPlayer == 'W'){
+                reward = -reward;
+            }
+
             for (StateAction step : gameHistory){
                 trainFromHistory(step, reward);
@@ -167,6 +177,15 @@ public class NeuralNetwork {
         return bestMove;
     }
 
+    private AI<Reversi> getOpponentAI(){
+        return switch ((int) (Math.random() * 4)) {
+            case 0 -> opponentRand;
+            case 1 -> opponentSimple;
+            case 2 -> opponentAIML;
+            default -> opponentRand;
+        };
+    }
+
     private void trainFromHistory(StateAction step, double reward){
         double[] output = new double[64];
         output[step.action] = reward;
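
Note: two of the training-loop changes work together. Each game the model is seated as a random colour (modelPlayer), and the terminal reward is negated whenever that colour is White, so one network learns from both seats while facing an opponent drawn from getOpponentAI(). Also worth knowing: Math.random() * 4 yields 0 through 3 and case 3 falls through to default, so the random opponent is drawn twice as often as the other two. A minimal sketch of the reward convention, assuming the 1 = Black / 2 = White winner codes that ReversiTest uses; rewardFor is an illustrative helper, not code from this repository:

// Sketch of the terminal-reward convention, assuming getWinner() returns
// 1 for Black and 2 for White as in ReversiTest.
public final class RewardSketch {

    static double rewardFor(char modelPlayer, int winner) {
        double reward = switch (winner) {
            case 1 -> 1.0;   // Black won
            case 2 -> -1.0;  // White won
            default -> 0.0;  // draw
        };
        // The base reward is scored from Black's perspective, so flip the
        // sign when the model was seated as White this game, mirroring the
        // diff's "if (modelPlayer == 'W') reward = -reward;" block.
        return modelPlayer == 'W' ? -reward : reward;
    }

    public static void main(String[] args) {
        System.out.println(rewardFor('B', 1)); //  1.0: learner was Black and won
        System.out.println(rewardFor('W', 1)); // -1.0: learner was White and lost
        System.out.println(rewardFor('W', 2)); //  1.0: learner was White and won
    }
}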

ReversiAIML.java (View File)

@@ -10,6 +10,8 @@ import org.toop.game.records.Move;
 import java.io.IOException;
 import java.io.InputStream;
 
+import static java.lang.Math.random;
+
 public class ReversiAIML extends AI<Reversi>{
     MultiLayerNetwork model;
@@ -33,18 +35,22 @@ public class ReversiAIML extends AI<Reversi>{
     }
 
     private int pickLegalMove(INDArray prediction, Reversi reversi) {
-        double[] probs = prediction.toDoubleVector();
+        double[] logits = prediction.toDoubleVector();
         Move[] legalMoves = reversi.getLegalMoves();
         if (legalMoves.length == 0) return -1;
 
         int bestMove = legalMoves[0].position();
-        double bestVal = probs[bestMove];
+        double bestVal = logits[bestMove];
+        if (random() < 0.01){
+            return legalMoves[(int)(random()*legalMoves.length-.5)].position();
+        }
 
         for (Move move : legalMoves) {
-            if (probs[move.position()] > bestVal){
-                bestMove = move.position();
-                bestVal = probs[bestMove];
+            int pos = move.position();
+            if (logits[pos] > bestVal) {
+                bestMove = pos;
+                bestVal = logits[pos];
             }
         }
         return bestMove;
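
One caveat on the new 1% exploration branch: the index expression (int)(random()*legalMoves.length-.5) truncates toward zero, so every draw in (-0.5, 1) maps to index 0. The first legal move is therefore picked roughly 1.5 times as often as a middle one, and the last about half as often; the plain (int)(random()*legalMoves.length) form is uniform. A standalone sketch (not repo code) that measures the difference:

import static java.lang.Math.random;

// Compare (int)(random()*n - .5) with (int)(random()*n) for n = 4.
// Truncation toward zero funnels all draws in (-0.5, 1) to index 0.
final class IndexBiasSketch {
    public static void main(String[] args) {
        int n = 4;
        int[] biased = new int[n];
        int[] uniform = new int[n];
        for (int i = 0; i < 400_000; i++) {
            biased[(int) (random() * n - .5)]++;  // as in pickLegalMove
            uniform[(int) (random() * n)]++;      // uniform over [0, n)
        }
        // Expect roughly [150k, 100k, 100k, 50k] vs [100k, 100k, 100k, 100k].
        System.out.println(java.util.Arrays.toString(biased));
        System.out.println(java.util.Arrays.toString(uniform));
    }
}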

ReversiAISimple.java (View File)

@@ -33,6 +33,12 @@ public class ReversiAISimple extends AI<Reversi> {
                 bestScore = numSco;
                 bestMoveScore = move;
             }
+            if (numSco == bestScore || numOpt == bestOptions) {
+                if (Math.random() < 0.5) {
+                    bestMoveOptions = move;
+                    bestMoveScore = move;
+                }
+            }
             //IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
         }
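
The new block breaks ties with a 50/50 coin, which skews toward later moves: among k tied candidates, the last one wins with probability 1/2 while the first survives only if every later flip fails, i.e. with probability (1/2)^(k-1). If uniform tie-breaking is the intent, replacing the current holder with probability 1/ties (reservoir sampling) restores it. A sketch under illustrative names, not repo code:

import java.util.concurrent.ThreadLocalRandom;

// Uniform tie-breaking via reservoir sampling; the int scores and index
// bookkeeping stand in for the Move/score fields in ReversiAISimple.
final class TieBreakSketch {

    static int pickBest(int[] scores) {
        int best = Integer.MIN_VALUE;
        int bestIdx = -1;
        int ties = 0;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > best) {
                best = scores[i];
                bestIdx = i;
                ties = 1;
            } else if (scores[i] == best) {
                ties++;
                // Replace the current holder with probability 1/ties; every
                // tied index then ends up selected with equal probability.
                if (ThreadLocalRandom.current().nextInt(ties) == 0) bestIdx = i;
            }
        }
        return bestIdx;
    }

    public static void main(String[] args) {
        int[] counts = new int[3];
        for (int i = 0; i < 90_000; i++) counts[pickBest(new int[]{5, 5, 5})]++;
        System.out.println(java.util.Arrays.toString(counts)); // ~30k each
    }
}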

ReversiTest.java (View File)

@@ -4,6 +4,7 @@ import java.util.*;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.toop.game.AI;
 import org.toop.game.enumerators.GameState;
 import org.toop.game.records.Move;
 import org.toop.game.reversi.Reversi;
@@ -18,6 +19,8 @@ class ReversiTest {
     private ReversiAI ai;
     private ReversiAIML aiml;
     private ReversiAISimple aiSimple;
+    private AI<Reversi> player1;
+    private AI<Reversi> player2;
 
     @BeforeEach
     void setup() {
@@ -198,31 +201,58 @@ class ReversiTest {
         return false;
     }
 
+    @Test
+    void testAis() {
+        player1 = aiml;
+        player2 = ai;
+        testAIvsAIML();
+        player2 = aiSimple;
+        testAIvsAIML();
+        player1 = ai;
+        testAIvsAIML();
+        player2 = aiml;
+        testAIvsAIML();
+        player1 = aiml;
+        testAIvsAIML();
+        player1 = aiSimple;
+        testAIvsAIML();
+    }
+
     @Test
     void testAIvsAIML() {
-        IO.println("Testing AI simple ...");
-        int totalGames = 5000;
+        if(player1 == null || player2 == null) {
+            player1 = aiml;
+            player2 = ai;
+        }
+        int totalGames = 2000;
+        IO.println("Testing... " + player1.getClass().getSimpleName() + " vs " + player2.getClass().getSimpleName() + " for " + totalGames + " games");
         int p1wins = 0;
         int p2wins = 0;
         int draws = 0;
+        List<Integer> moves = new ArrayList<>();
         for (int i = 0; i < totalGames; i++) {
             game = new Reversi();
             while (!game.isGameOver()) {
                 char curr = game.getCurrentPlayer();
-                if (curr == 'W') {
-                    game.play(ai.findBestMove(game,5));
-                }
-                else {
-                    game.play(ai.findBestMove(game,5));
-                }
+                Move move;
+                if (curr == 'B') {
+                    move = player1.findBestMove(game, 5);
+                } else {
+                    move = player2.findBestMove(game, 5);
+                }
+                if (i%500 == 0) moves.add(move.position());
+                game.play(move);
             }
+            if (i%500 == 0) {
+                IO.println(moves);
+                moves.clear();
+            }
             int winner = game.getWinner();
             if (winner == 1) {
                 p1wins++;
             } else if (winner == 2) {
                 p2wins++;
-            }
-            else{
+            } else {
                 draws++;
             }
         }
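
Since testAIvsAIML now reads its contestants from the shared player1/player2 fields, testAis drives six matchups through one method, and the null guard keeps testAIvsAIML runnable on its own. An alternative shape for the same idea is a JUnit 5 parameterized test, which reports each pairing separately; a sketch assuming junit-jupiter-params is on the test classpath (MatchupSketch and its pairings are illustrative, not repo code):

import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.toop.game.AI;
import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI;
import org.toop.game.reversi.ReversiAIML;
import org.toop.game.reversi.ReversiAISimple;

// The fixed pairings from testAis() expressed as a parameterized test,
// so each matchup passes or fails on its own.
class MatchupSketch {

    static Stream<Arguments> pairings() {
        return Stream.of(
                Arguments.of(new ReversiAIML(), new ReversiAI()),
                Arguments.of(new ReversiAIML(), new ReversiAISimple()),
                Arguments.of(new ReversiAI(), new ReversiAISimple()));
    }

    @ParameterizedTest
    @MethodSource("pairings")
    void plays(AI<Reversi> black, AI<Reversi> white) {
        // One game per pairing; tally wins over many games here, as
        // testAIvsAIML does, to turn this into a real strength check.
        Reversi game = new Reversi();
        while (!game.isGameOver()) {
            char curr = game.getCurrentPlayer();
            game.play((curr == 'B' ? black : white).findBestMove(game, 5));
        }
    }
}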