mirror of https://github.com/2OOP/pism.git, synced 2026-02-04 10:54:51 +00:00
added some useful testing methods.
made training slightly better.
@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
import org.toop.game.records.Move;
import org.toop.game.reversi.Reversi;
import org.toop.game.reversi.ReversiAI;
import org.toop.game.reversi.ReversiAIML;
import org.toop.game.reversi.ReversiAISimple;

import java.io.File;
@@ -26,14 +27,17 @@ import java.util.ArrayList;
import java.util.List;

import static java.lang.Math.abs;
import static java.lang.Math.random;

public class NeuralNetwork {

    private MultiLayerConfiguration conf;
    private MultiLayerNetwork model;
    private AI<Reversi> opponentAI;
    private ReversiAI reversiAI = new ReversiAI();
    private AI<Reversi> opponentRand = new ReversiAI();
    private AI<Reversi> opponentSimple = new ReversiAISimple();
    private AI<Reversi> opponentAIML = new ReversiAIML();


    public NeuralNetwork() {}
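The fields above wire in three scripted opponents (a random player, a simple heuristic, and an AIML-based one) alongside the search-based ReversiAI. As a point of reference for the "testing methods" mentioned in the commit message, here is a minimal sketch, not part of the commit, of how two scripted opponents could be played against each other using only the API visible in this diff (Reversi, GameState, Move, AI.findBestMove, Reversi.play):

    // Illustrative only: one head-to-head game between two scripted opponents.
    private static GameState playOut(AI<Reversi> black, AI<Reversi> white) {
        Reversi reversi = new Reversi();
        GameState state = GameState.NORMAL;
        // Alternate moves until the game reports a win or a draw, exactly as the
        // training loop below does; depth 5 mirrors the findBestMove calls in this diff.
        while (state != GameState.DRAW && state != GameState.WIN) {
            AI<Reversi> toMove = reversi.getCurrentPlayer() == 'B' ? black : white;
            state = reversi.play(toMove.findBestMove(reversi, 5));
        }
        return state;
    }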
@@ -61,7 +65,7 @@ public class NeuralNetwork {
        model.init();
        IO.println(model.summary());

-        model.setLearningRate(0.001);
+        model.setLearningRate(0.0003);
        trainingLoop();
        saveModel();
    }
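The learning rate passed to setLearningRate drops from 0.001 to 0.0003 here. The code that builds conf and model is not shown in this diff; the following is only a rough, hypothetical sketch of a DL4J setup compatible with the fields above (layer sizes, activations and the Adam updater are assumptions; only the 64-cell board input and 64-square output follow from the rest of the file):

    // Assumed imports: org.deeplearning4j.nn.conf.*, org.deeplearning4j.nn.conf.layers.*,
    // org.deeplearning4j.nn.multilayer.MultiLayerNetwork, org.nd4j.linalg.activations.Activation,
    // org.nd4j.linalg.learning.config.Adam, org.nd4j.linalg.lossfunctions.LossFunctions
    conf = new NeuralNetConfiguration.Builder()
            .updater(new Adam(0.0003))          // same value as the new learning rate
            .list()
            .layer(new DenseLayer.Builder().nIn(64).nOut(128)
                    .activation(Activation.RELU).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .nIn(128).nOut(64)
                    .activation(Activation.IDENTITY).build())
            .build();
    model = new MultiLayerNetwork(conf);
    model.init();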
@@ -85,13 +89,15 @@ public class NeuralNetwork {
    }

    public void trainingLoop(){
-        int totalGames = 5000;
-        double epsilon = 0.1;
+        int totalGames = 50000;
+        double epsilon = 0.05;

        long start = System.nanoTime();

        for (int game = 0; game<totalGames; game++){
+            char modelPlayer = random()<0.5?'B':'W';
            Reversi reversi = new Reversi();
+            opponentAI = getOpponentAI();
            List<StateAction> gameHistory = new ArrayList<>();
            GameState state = GameState.NORMAL;
@@ -100,7 +106,7 @@ public class NeuralNetwork {
            while (state != GameState.DRAW && state != GameState.WIN){
                char curr = reversi.getCurrentPlayer();
                Move move;
-                if (curr == 'B') {
+                if (curr == modelPlayer) {
                    int[] input = reversi.getBoardInt();
                    if (Math.random() < epsilon) {
                        Move[] moves = reversi.getLegalMoves();
@@ -114,7 +120,7 @@ public class NeuralNetwork {
                        move = new Move(location, reversi.getCurrentPlayer());
                    }
                }else{
-                    move = reversiAI.findBestMove(reversi,5);
+                    move = opponentAI.findBestMove(reversi,5);
                }
                state = reversi.play(move);
            }
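Taken together, the two hunks above implement an epsilon-greedy policy for the model's side: with probability epsilon a random legal move is played, otherwise the network's preferred square, while the scripted opponent always plays its depth-5 search move. A condensed sketch of that selection, where bestMoveFromNetwork(...) is a stand-in name for the network argmax step whose real method name is not visible in this diff:

    if (curr == modelPlayer) {
        Move[] legal = reversi.getLegalMoves();
        if (Math.random() < epsilon) {
            // exploration: pick a uniformly random legal move
            move = legal[(int) (Math.random() * legal.length)];
        } else {
            // exploitation: the network's highest-scoring legal square
            move = bestMoveFromNetwork(reversi.getBoardInt(), legal);  // hypothetical helper
        }
    } else {
        move = opponentAI.findBestMove(reversi, 5);   // scripted opponent
    }
    state = reversi.play(move);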
@@ -130,6 +136,10 @@ public class NeuralNetwork {
                reward = 0;
            }
+
+            if (modelPlayer == 'W'){
+                reward = -reward;
+            }


            for (StateAction step : gameHistory){
                trainFromHistory(step, reward);
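The newly added block flips the sign of the reward when the model played White, so every recorded step is scored from the model's own perspective before being replayed through trainFromHistory. Only the draw case (reward = 0) and the sign flip are visible in this hunk; below is a sketch of the full terminal reward under the usual +1/-1 convention, where the win/loss values and the winner variable are assumptions rather than code from the commit:

    double reward;
    if (state == GameState.DRAW) {
        reward = 0;
    } else {
        // 'winner' is hypothetical; the diff does not show how the result is obtained.
        reward = (winner == 'B') ? 1 : -1;   // scored from Black's point of view
    }
    if (modelPlayer == 'W') {
        reward = -reward;                    // convert to the model's point of view
    }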
@@ -167,6 +177,15 @@ public class NeuralNetwork {
        return bestMove;
    }
+
+    private AI<Reversi> getOpponentAI(){
+        return switch ((int) (Math.random() * 4)) {
+            case 0 -> opponentRand;
+            case 1 -> opponentSimple;
+            case 2 -> opponentAIML;
+            default -> opponentRand;
+        };
+    }

    private void trainFromHistory(StateAction step, double reward){
        double[] output = new double[64];
        output[step.action] = reward;
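Note that in getOpponentAI() the roll (int) (Math.random() * 4) can also produce 3, which falls through to the default branch, so the random opponent is drawn roughly half the time and the other two a quarter each. trainFromHistory builds a 64-element target with the reward written at the played square; the diff is cut off before the network is actually updated, so the following continuation is only a guess at how the stored board and target might be handed to DL4J (it assumes StateAction keeps the int[] board from getBoardInt() in a field named state, and assumes the imports org.nd4j.linalg.api.ndarray.INDArray and org.nd4j.linalg.factory.Nd4j):

    // Hypothetical continuation of trainFromHistory(...), not code from the commit.
    double[] board = new double[step.state.length];
    for (int i = 0; i < board.length; i++) {
        board[i] = step.state[i];           // convert the stored int[] board to doubles
    }
    INDArray features = Nd4j.create(new double[][]{ board });    // shape [1, 64]
    INDArray labels   = Nd4j.create(new double[][]{ output });   // shape [1, 64]
    model.fit(features, labels);            // one gradient step toward the reward target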