Added some useful testing methods; the training loop can now draw its opponent at random from several AIs (ReversiAI, ReversiAISimple, ReversiAIML).

Made training slightly better: lower learning rate (0.0003 instead of 0.001), more games (50,000), a smaller exploration epsilon (0.05), and the model is now assigned a random color each game instead of always playing Black.
Ticho Hidding
2025-12-08 11:36:31 +01:00
parent 7e913ff50f
commit f6d90ed439
5 changed files with 110 additions and 49 deletions


@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI;
+import org.toop.game.reversi.ReversiAIML;
import org.toop.game.reversi.ReversiAISimple;
import java.io.File;
@@ -26,14 +27,17 @@ import java.util.ArrayList;
import java.util.List;
import static java.lang.Math.abs;
import static java.lang.Math.random;
public class NeuralNetwork {
private MultiLayerConfiguration conf;
private MultiLayerNetwork model;
+private AI<Reversi> opponentAI;
-private ReversiAI reversiAI = new ReversiAI();
+private AI<Reversi> opponentRand = new ReversiAI();
+private AI<Reversi> opponentSimple = new ReversiAISimple();
+private AI<Reversi> opponentAIML = new ReversiAIML();
public NeuralNetwork() {}
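All of the new opponent fields share the AI<Reversi> type, which is what lets the trainer swap opponents per game. The interface itself is not part of this diff; a minimal sketch of the shape these fields assume (only the findBestMove(game, depth) call used later in trainingLoop() actually appears in the commit) could look like this:

// Hypothetical sketch of the AI interface assumed by the opponent fields.
// The real org.toop.game type is not shown in this commit; only the
// findBestMove(reversi, 5) call from trainingLoop() is taken from the diff.
public interface AI<G> {
    // Return the move this AI would play in the given position, searching up
    // to the given depth (simple or random AIs may ignore the depth).
    Move findBestMove(G game, int depth);
}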
@@ -61,7 +65,7 @@ public class NeuralNetwork {
model.init();
IO.println(model.summary());
-model.setLearningRate(0.001);
+model.setLearningRate(0.0003);
trainingLoop();
saveModel();
}
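The learning rate drops from 0.001 to 0.0003 via MultiLayerNetwork.setLearningRate. The network configuration itself is outside this hunk; a minimal Deeplearning4j setup consistent with the 64-square board input and the 64-value output used further down might look like the sketch below. The hidden-layer size, activations, and Adam updater are assumptions, only the 64-wide input/output and the 0.0003 rate come from the diff.

// Sketch of a possible DL4J configuration for an 8x8 Reversi board.
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class ConfigSketch {
    static MultiLayerNetwork build() {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(0.0003))                      // same rate as setLearningRate(0.0003)
                .list()
                .layer(new DenseLayer.Builder().nIn(64).nOut(128)
                        .activation(Activation.RELU).build())   // hidden width is a guess
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .nIn(128).nOut(64)                      // one output per board square
                        .activation(Activation.IDENTITY).build())
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        return model;
    }
}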
@@ -85,13 +89,15 @@ public class NeuralNetwork {
}
public void trainingLoop(){
-int totalGames = 5000;
-double epsilon = 0.1;
+int totalGames = 50000;
+double epsilon = 0.05;
long start = System.nanoTime();
for (int game = 0; game<totalGames; game++){
+char modelPlayer = random()<0.5?'B':'W';
Reversi reversi = new Reversi();
+opponentAI = getOpponentAI();
List<StateAction> gameHistory = new ArrayList<>();
GameState state = GameState.NORMAL;
@@ -100,7 +106,7 @@ public class NeuralNetwork {
while (state != GameState.DRAW && state != GameState.WIN){
char curr = reversi.getCurrentPlayer();
Move move;
-if (curr == 'B') {
+if (curr == modelPlayer) {
int[] input = reversi.getBoardInt();
if (Math.random() < epsilon) {
Move[] moves = reversi.getLegalMoves();
@@ -114,7 +120,7 @@ public class NeuralNetwork {
move = new Move(location, reversi.getCurrentPlayer());
}
}else{
-move = reversiAI.findBestMove(reversi,5);
+move = opponentAI.findBestMove(reversi,5);
}
state = reversi.play(move);
}
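On the model's turn the move is chosen epsilon-greedily: with probability epsilon (now 0.05) a random legal move is played for exploration, otherwise the network's highest-scoring legal square is taken; on the opponent's turn the randomly selected opponent AI searches to depth 5. A compact sketch of the model-side selection is below. It is meant to live inside NeuralNetwork; getLegalMoves(), getBoardInt(), and Move come from the diff, while scoreBoard() (a forward pass returning 64 scores) and the Move.location() accessor are assumptions, since those parts are collapsed in this view.

// Sketch of the epsilon-greedy move choice, under assumed helpers.
Move pickModelMove(Reversi reversi, double epsilon) {
    Move[] legal = reversi.getLegalMoves();
    if (Math.random() < epsilon) {
        // explore: uniform random legal move
        return legal[(int) (Math.random() * legal.length)];
    }
    double[] scores = scoreBoard(reversi.getBoardInt()); // assumed forward pass, 64 values
    Move best = legal[0];
    double bestScore = Double.NEGATIVE_INFINITY;
    for (Move m : legal) {
        int idx = m.location();                          // assumed record accessor for the square index
        if (scores[idx] > bestScore) {
            bestScore = scores[idx];
            best = m;
        }
    }
    return best;
}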
@@ -130,6 +136,10 @@ public class NeuralNetwork {
reward = 0;
}
+if (modelPlayer == 'W'){
+reward = -reward;
+}
for (StateAction step : gameHistory){
trainFromHistory(step, reward);
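Because the model can now be assigned either color, the terminal reward is flipped when it played White, so that a positive reward always means "the model won" before it is propagated to every recorded state/action of the game. A sketch of that reward computation is below, under the assumption (consistent with the visible lines, which set 0 on a draw) that the base reward is computed from Black's point of view; the winner parameter is hypothetical since the win detection is collapsed in this view.

// Sketch of the per-game reward seen by trainFromHistory().
double rewardFor(GameState state, char winner, char modelPlayer) {
    double reward;
    if (state == GameState.DRAW) {
        reward = 0;
    } else {
        reward = (winner == 'B') ? 1.0 : -1.0;  // Black's perspective (assumed)
    }
    if (modelPlayer == 'W') {
        reward = -reward;                       // flip so +1 always means the model won
    }
    return reward;
}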
@@ -167,6 +177,15 @@ public class NeuralNetwork {
return bestMove;
}
+private AI<Reversi> getOpponentAI(){
+return switch ((int) (Math.random() * 4)) {
+case 0 -> opponentRand;
+case 1 -> opponentSimple;
+case 2 -> opponentAIML;
+default -> opponentRand;
+};
+}
private void trainFromHistory(StateAction step, double reward){
double[] output = new double[64];
output[step.action] = reward;
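trainFromHistory builds a 64-value target with the final reward written at the index of the square that was played, then fits the network on the board state that preceded the move. The rest of the method is collapsed in this view; a hedged completion is sketched below, assuming StateAction carries the board as an int[] state field alongside the action index shown in the diff, and reusing the class's model field.

// Hedged completion of trainFromHistory(); only the 64-slot target with the
// reward at step.action comes from the diff, the rest is assumed.
// Requires org.nd4j.linalg.api.ndarray.INDArray and org.nd4j.linalg.factory.Nd4j.
void trainFromHistorySketch(StateAction step, double reward) {
    double[] target = new double[64];
    target[step.action] = reward;                   // supervise only the played square

    double[] boardInput = new double[64];
    for (int i = 0; i < 64; i++) {
        boardInput[i] = step.state[i];              // assumed: StateAction.state is the int[64] board
    }

    INDArray features = Nd4j.create(boardInput, new int[]{1, 64});
    INDArray labels = Nd4j.create(target, new int[]{1, 64});
    model.fit(features, labels);                    // one gradient step on this position
}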