mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 10:54:51 +00:00
added some useful testing methods.
made training slightly better.
This commit is contained in:
@@ -38,8 +38,8 @@ public final class Main {
|
|||||||
}).start();
|
}).start();
|
||||||
|
|
||||||
|
|
||||||
//NeuralNetwork nn = new NeuralNetwork();
|
NeuralNetwork nn = new NeuralNetwork();
|
||||||
//nn.init();
|
nn.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import org.toop.game.enumerators.GameState;
|
|||||||
import org.toop.game.records.Move;
|
import org.toop.game.records.Move;
|
||||||
import org.toop.game.reversi.Reversi;
|
import org.toop.game.reversi.Reversi;
|
||||||
import org.toop.game.reversi.ReversiAI;
|
import org.toop.game.reversi.ReversiAI;
|
||||||
|
import org.toop.game.reversi.ReversiAIML;
|
||||||
import org.toop.game.reversi.ReversiAISimple;
|
import org.toop.game.reversi.ReversiAISimple;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
@@ -26,14 +27,17 @@ import java.util.ArrayList;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static java.lang.Math.abs;
|
import static java.lang.Math.abs;
|
||||||
|
import static java.lang.Math.random;
|
||||||
|
|
||||||
public class NeuralNetwork {
|
public class NeuralNetwork {
|
||||||
|
|
||||||
private MultiLayerConfiguration conf;
|
private MultiLayerConfiguration conf;
|
||||||
private MultiLayerNetwork model;
|
private MultiLayerNetwork model;
|
||||||
|
private AI<Reversi> opponentAI;
|
||||||
private ReversiAI reversiAI = new ReversiAI();
|
private ReversiAI reversiAI = new ReversiAI();
|
||||||
private AI<Reversi> opponentRand = new ReversiAI();
|
private AI<Reversi> opponentRand = new ReversiAI();
|
||||||
private AI<Reversi> opponentSimple = new ReversiAISimple();
|
private AI<Reversi> opponentSimple = new ReversiAISimple();
|
||||||
|
private AI<Reversi> opponentAIML = new ReversiAIML();
|
||||||
|
|
||||||
|
|
||||||
public NeuralNetwork() {}
|
public NeuralNetwork() {}
|
||||||
@@ -61,7 +65,7 @@ public class NeuralNetwork {
|
|||||||
model.init();
|
model.init();
|
||||||
IO.println(model.summary());
|
IO.println(model.summary());
|
||||||
|
|
||||||
model.setLearningRate(0.001);
|
model.setLearningRate(0.0003);
|
||||||
trainingLoop();
|
trainingLoop();
|
||||||
saveModel();
|
saveModel();
|
||||||
}
|
}
|
||||||
@@ -85,13 +89,15 @@ public class NeuralNetwork {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void trainingLoop(){
|
public void trainingLoop(){
|
||||||
int totalGames = 5000;
|
int totalGames = 50000;
|
||||||
double epsilon = 0.1;
|
double epsilon = 0.05;
|
||||||
|
|
||||||
long start = System.nanoTime();
|
long start = System.nanoTime();
|
||||||
|
|
||||||
for (int game = 0; game<totalGames; game++){
|
for (int game = 0; game<totalGames; game++){
|
||||||
|
char modelPlayer = random()<0.5?'B':'W';
|
||||||
Reversi reversi = new Reversi();
|
Reversi reversi = new Reversi();
|
||||||
|
opponentAI = getOpponentAI();
|
||||||
List<StateAction> gameHistory = new ArrayList<>();
|
List<StateAction> gameHistory = new ArrayList<>();
|
||||||
GameState state = GameState.NORMAL;
|
GameState state = GameState.NORMAL;
|
||||||
|
|
||||||
@@ -100,7 +106,7 @@ public class NeuralNetwork {
|
|||||||
while (state != GameState.DRAW && state != GameState.WIN){
|
while (state != GameState.DRAW && state != GameState.WIN){
|
||||||
char curr = reversi.getCurrentPlayer();
|
char curr = reversi.getCurrentPlayer();
|
||||||
Move move;
|
Move move;
|
||||||
if (curr == 'B') {
|
if (curr == modelPlayer) {
|
||||||
int[] input = reversi.getBoardInt();
|
int[] input = reversi.getBoardInt();
|
||||||
if (Math.random() < epsilon) {
|
if (Math.random() < epsilon) {
|
||||||
Move[] moves = reversi.getLegalMoves();
|
Move[] moves = reversi.getLegalMoves();
|
||||||
@@ -114,7 +120,7 @@ public class NeuralNetwork {
|
|||||||
move = new Move(location, reversi.getCurrentPlayer());
|
move = new Move(location, reversi.getCurrentPlayer());
|
||||||
}
|
}
|
||||||
}else{
|
}else{
|
||||||
move = reversiAI.findBestMove(reversi,5);
|
move = opponentAI.findBestMove(reversi,5);
|
||||||
}
|
}
|
||||||
state = reversi.play(move);
|
state = reversi.play(move);
|
||||||
}
|
}
|
||||||
@@ -130,6 +136,10 @@ public class NeuralNetwork {
|
|||||||
reward = 0;
|
reward = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (modelPlayer == 'W'){
|
||||||
|
reward = -reward;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
for (StateAction step : gameHistory){
|
for (StateAction step : gameHistory){
|
||||||
trainFromHistory(step, reward);
|
trainFromHistory(step, reward);
|
||||||
@@ -167,6 +177,15 @@ public class NeuralNetwork {
|
|||||||
return bestMove;
|
return bestMove;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private AI<Reversi> getOpponentAI(){
|
||||||
|
return switch ((int) (Math.random() * 4)) {
|
||||||
|
case 0 -> opponentRand;
|
||||||
|
case 1 -> opponentSimple;
|
||||||
|
case 2 -> opponentAIML;
|
||||||
|
default -> opponentRand;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private void trainFromHistory(StateAction step, double reward){
|
private void trainFromHistory(StateAction step, double reward){
|
||||||
double[] output = new double[64];
|
double[] output = new double[64];
|
||||||
output[step.action] = reward;
|
output[step.action] = reward;
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import org.toop.game.records.Move;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
|
import static java.lang.Math.random;
|
||||||
|
|
||||||
public class ReversiAIML extends AI<Reversi>{
|
public class ReversiAIML extends AI<Reversi>{
|
||||||
|
|
||||||
MultiLayerNetwork model;
|
MultiLayerNetwork model;
|
||||||
@@ -28,23 +30,27 @@ public class ReversiAIML extends AI<Reversi>{
|
|||||||
INDArray boardInput = Nd4j.create(new int[][] { input });
|
INDArray boardInput = Nd4j.create(new int[][] { input });
|
||||||
INDArray prediction = model.output(boardInput);
|
INDArray prediction = model.output(boardInput);
|
||||||
|
|
||||||
int move = pickLegalMove(prediction, reversi);
|
int move = pickLegalMove(prediction,reversi);
|
||||||
return new Move(move, reversi.getCurrentPlayer());
|
return new Move(move, reversi.getCurrentPlayer());
|
||||||
}
|
}
|
||||||
|
|
||||||
private int pickLegalMove(INDArray prediction, Reversi reversi){
|
private int pickLegalMove(INDArray prediction, Reversi reversi) {
|
||||||
double[] probs = prediction.toDoubleVector();
|
double[] logits = prediction.toDoubleVector();
|
||||||
Move[] legalMoves = reversi.getLegalMoves();
|
Move[] legalMoves = reversi.getLegalMoves();
|
||||||
|
|
||||||
if (legalMoves.length == 0) return -1;
|
if (legalMoves.length == 0) return -1;
|
||||||
|
|
||||||
int bestMove = legalMoves[0].position();
|
int bestMove = legalMoves[0].position();
|
||||||
double bestVal = probs[bestMove];
|
double bestVal = logits[bestMove];
|
||||||
|
|
||||||
for (Move move : legalMoves){
|
if (random() < 0.01){
|
||||||
if (probs[move.position()] > bestVal){
|
return legalMoves[(int)(random()*legalMoves.length-.5)].position();
|
||||||
bestMove = move.position();
|
}
|
||||||
bestVal = probs[bestMove];
|
for (Move move : legalMoves) {
|
||||||
|
int pos = move.position();
|
||||||
|
if (logits[pos] > bestVal) {
|
||||||
|
bestMove = pos;
|
||||||
|
bestVal = logits[pos];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return bestMove;
|
return bestMove;
|
||||||
|
|||||||
@@ -33,6 +33,12 @@ public class ReversiAISimple extends AI<Reversi> {
|
|||||||
bestScore = numSco;
|
bestScore = numSco;
|
||||||
bestMoveScore = move;
|
bestMoveScore = move;
|
||||||
}
|
}
|
||||||
|
if (numSco == bestScore || numOpt == bestOptions) {
|
||||||
|
if (Math.random() < 0.5) {
|
||||||
|
bestMoveOptions = move;
|
||||||
|
bestMoveScore = move;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
|
//IO.println("Move: " + move.position() + ". Options: " + numOpt + ". Score: " + numSco);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import java.util.*;
|
|||||||
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.toop.game.AI;
|
||||||
import org.toop.game.enumerators.GameState;
|
import org.toop.game.enumerators.GameState;
|
||||||
import org.toop.game.records.Move;
|
import org.toop.game.records.Move;
|
||||||
import org.toop.game.reversi.Reversi;
|
import org.toop.game.reversi.Reversi;
|
||||||
@@ -18,6 +19,8 @@ class ReversiTest {
|
|||||||
private ReversiAI ai;
|
private ReversiAI ai;
|
||||||
private ReversiAIML aiml;
|
private ReversiAIML aiml;
|
||||||
private ReversiAISimple aiSimple;
|
private ReversiAISimple aiSimple;
|
||||||
|
private AI<Reversi> player1;
|
||||||
|
private AI<Reversi> player2;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
void setup() {
|
void setup() {
|
||||||
@@ -32,20 +35,20 @@ class ReversiTest {
|
|||||||
@Test
|
@Test
|
||||||
void testCorrectStartPiecesPlaced() {
|
void testCorrectStartPiecesPlaced() {
|
||||||
assertNotNull(game);
|
assertNotNull(game);
|
||||||
assertEquals('W',game.getBoard()[27]);
|
assertEquals('W', game.getBoard()[27]);
|
||||||
assertEquals('B',game.getBoard()[28]);
|
assertEquals('B', game.getBoard()[28]);
|
||||||
assertEquals('B',game.getBoard()[35]);
|
assertEquals('B', game.getBoard()[35]);
|
||||||
assertEquals('W',game.getBoard()[36]);
|
assertEquals('W', game.getBoard()[36]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testGetLegalMovesAtStart() {
|
void testGetLegalMovesAtStart() {
|
||||||
Move[] moves = game.getLegalMoves();
|
Move[] moves = game.getLegalMoves();
|
||||||
List<Move> expectedMoves = List.of(
|
List<Move> expectedMoves = List.of(
|
||||||
new Move(19,'B'),
|
new Move(19, 'B'),
|
||||||
new Move(26,'B'),
|
new Move(26, 'B'),
|
||||||
new Move(37,'B'),
|
new Move(37, 'B'),
|
||||||
new Move(44,'B')
|
new Move(44, 'B')
|
||||||
);
|
);
|
||||||
assertNotNull(moves);
|
assertNotNull(moves);
|
||||||
assertTrue(moves.length > 0);
|
assertTrue(moves.length > 0);
|
||||||
@@ -83,12 +86,12 @@ class ReversiTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testCountScoreCorrectlyAtStart() {
|
void testCountScoreCorrectlyAtStart() {
|
||||||
long start = System.nanoTime();
|
long start = System.nanoTime();
|
||||||
Reversi.Score score = game.getScore();
|
Reversi.Score score = game.getScore();
|
||||||
assertEquals(2, score.player1Score()); // Black
|
assertEquals(2, score.player1Score()); // Black
|
||||||
assertEquals(2, score.player2Score()); // White
|
assertEquals(2, score.player2Score()); // White
|
||||||
long end = System.nanoTime();
|
long end = System.nanoTime();
|
||||||
IO.println((end-start));
|
IO.println((end - start));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -97,7 +100,7 @@ class ReversiTest {
|
|||||||
game.play(new Move(20, 'W'));
|
game.play(new Move(20, 'W'));
|
||||||
Move[] moves = game.getLegalMoves();
|
Move[] moves = game.getLegalMoves();
|
||||||
List<Move> expectedMoves = List.of(
|
List<Move> expectedMoves = List.of(
|
||||||
new Move(13,'B'),
|
new Move(13, 'B'),
|
||||||
new Move(21, 'B'),
|
new Move(21, 'B'),
|
||||||
new Move(29, 'B'),
|
new Move(29, 'B'),
|
||||||
new Move(37, 'B'),
|
new Move(37, 'B'),
|
||||||
@@ -110,11 +113,11 @@ class ReversiTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testCountScoreCorrectlyAtEnd() {
|
void testCountScoreCorrectlyAtEnd() {
|
||||||
for (int i = 0; i < 1; i++){
|
for (int i = 0; i < 1; i++) {
|
||||||
game = new Reversi();
|
game = new Reversi();
|
||||||
Move[] legalMoves = game.getLegalMoves();
|
Move[] legalMoves = game.getLegalMoves();
|
||||||
while(legalMoves.length > 0) {
|
while (legalMoves.length > 0) {
|
||||||
game.play(legalMoves[(int)(Math.random()*legalMoves.length)]);
|
game.play(legalMoves[(int) (Math.random() * legalMoves.length)]);
|
||||||
legalMoves = game.getLegalMoves();
|
legalMoves = game.getLegalMoves();
|
||||||
}
|
}
|
||||||
Reversi.Score score = game.getScore();
|
Reversi.Score score = game.getScore();
|
||||||
@@ -152,8 +155,8 @@ class ReversiTest {
|
|||||||
game.play(new Move(60, 'W'));
|
game.play(new Move(60, 'W'));
|
||||||
game.play(new Move(59, 'B'));
|
game.play(new Move(59, 'B'));
|
||||||
assertEquals('B', game.getCurrentPlayer());
|
assertEquals('B', game.getCurrentPlayer());
|
||||||
game.play(ai.findBestMove(game,5));
|
game.play(ai.findBestMove(game, 5));
|
||||||
game.play(ai.findBestMove(game,5));
|
game.play(ai.findBestMove(game, 5));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -186,9 +189,9 @@ class ReversiTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAISelectsLegalMove() {
|
void testAISelectsLegalMove() {
|
||||||
Move move = ai.findBestMove(game,4);
|
Move move = ai.findBestMove(game, 4);
|
||||||
assertNotNull(move);
|
assertNotNull(move);
|
||||||
assertTrue(containsMove(game.getLegalMoves(),move), "AI should always choose a legal move");
|
assertTrue(containsMove(game.getLegalMoves(), move), "AI should always choose a legal move");
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean containsMove(Move[] moves, Move move) {
|
private boolean containsMove(Move[] moves, Move move) {
|
||||||
@@ -199,33 +202,60 @@ class ReversiTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testAIvsAIML(){
|
void testAis() {
|
||||||
IO.println("Testing AI simple ...");
|
player1 = aiml;
|
||||||
int totalGames = 5000;
|
player2 = ai;
|
||||||
|
testAIvsAIML();
|
||||||
|
player2 = aiSimple;
|
||||||
|
testAIvsAIML();
|
||||||
|
player1 = ai;
|
||||||
|
testAIvsAIML();
|
||||||
|
player2 = aiml;
|
||||||
|
testAIvsAIML();
|
||||||
|
player1 = aiml;
|
||||||
|
testAIvsAIML();
|
||||||
|
player1 = aiSimple;
|
||||||
|
testAIvsAIML();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testAIvsAIML() {
|
||||||
|
if(player1 == null || player2 == null) {
|
||||||
|
player1 = aiml;
|
||||||
|
player2 = ai;
|
||||||
|
}
|
||||||
|
int totalGames = 2000;
|
||||||
|
IO.println("Testing... " + player1.getClass().getSimpleName() + " vs " + player2.getClass().getSimpleName() + " for " + totalGames + " games");
|
||||||
int p1wins = 0;
|
int p1wins = 0;
|
||||||
int p2wins = 0;
|
int p2wins = 0;
|
||||||
int draws = 0;
|
int draws = 0;
|
||||||
|
List<Integer> moves = new ArrayList<>();
|
||||||
for (int i = 0; i < totalGames; i++) {
|
for (int i = 0; i < totalGames; i++) {
|
||||||
game = new Reversi();
|
game = new Reversi();
|
||||||
while (!game.isGameOver()) {
|
while (!game.isGameOver()) {
|
||||||
char curr = game.getCurrentPlayer();
|
char curr = game.getCurrentPlayer();
|
||||||
if (curr == 'W') {
|
Move move;
|
||||||
game.play(ai.findBestMove(game,5));
|
if (curr == 'B') {
|
||||||
}
|
move = player1.findBestMove(game, 5);
|
||||||
else {
|
} else {
|
||||||
game.play(ai.findBestMove(game,5));
|
move = player2.findBestMove(game, 5);
|
||||||
}
|
}
|
||||||
|
if (i%500 == 0) moves.add(move.position());
|
||||||
|
game.play(move);
|
||||||
|
}
|
||||||
|
if (i%500 == 0) {
|
||||||
|
IO.println(moves);
|
||||||
|
moves.clear();
|
||||||
}
|
}
|
||||||
int winner = game.getWinner();
|
int winner = game.getWinner();
|
||||||
if (winner == 1) {
|
if (winner == 1) {
|
||||||
p1wins++;
|
p1wins++;
|
||||||
}else if (winner == 2) {
|
} else if (winner == 2) {
|
||||||
p2wins++;
|
p2wins++;
|
||||||
}
|
} else {
|
||||||
else{
|
|
||||||
draws++;
|
draws++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double)p1wins/totalGames + "\np2wins: " + p2wins + " draws: " + draws);
|
IO.println("p1 winrate: " + p1wins + "/" + totalGames + " = " + (double) p1wins / totalGames + "\np2wins: " + p2wins + " draws: " + draws);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user