Better data collection for overnight run

This commit is contained in:
lieght
2026-01-21 23:43:11 +01:00
parent 5d2fff7ae7
commit 992523b936
2 changed files with 137 additions and 44 deletions

View File

@@ -1,6 +1,7 @@
package research;
import org.apache.maven.surefire.shared.io.FileDeleteStrategy;
import org.junit.jupiter.api.*;
import org.toop.framework.game.games.reversi.BitboardReversi;
import org.toop.framework.game.players.ArtificialPlayer;
@@ -10,10 +11,22 @@ import org.toop.game.players.ai.mcts.MCTSAI2;
import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.mcts.MCTSAI4;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.text.DecimalFormat;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class AITest {
@@ -26,16 +39,19 @@ public class AITest {
public static void init() {
var versions = new ArtificialPlayer[4];
versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
versions[2] = new ArtificialPlayer(new MCTSAI3(10), "MCTS V3");
versions[3] = new ArtificialPlayer(new MCTSAI4(10), "MCTS V4");
for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length;
final int playerIndex2 = j % versions.length;
addMatch(versions[playerIndex1], versions[playerIndex2]);
addMatch(versions[playerIndex2], versions[playerIndex1]); // home vs away system
versions[0] = new ArtificialPlayer(new MCTSAI1(5), "MCTS V1");
versions[1] = new ArtificialPlayer(new MCTSAI2(5), "MCTS V2");
versions[2] = new ArtificialPlayer(new MCTSAI3(5), "MCTS V3");
versions[3] = new ArtificialPlayer(new MCTSAI4(5), "MCTS V4");
for (int k = 0; k < 50; k++) {
for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length;
final int playerIndex2 = j % versions.length;
addMatch(versions[playerIndex1], versions[playerIndex2]);
addMatch(versions[playerIndex2], versions[playerIndex1]); // home vs away system
}
}
}
}
@@ -55,9 +71,7 @@ public class AITest {
@Test
public void testAIvsAI() {
for (Matchup m : matchupList) {
for (int i = 0; i < 6; i++) {
playGame(m);
}
playGame(m);
}
}
@@ -88,19 +102,57 @@ public class AITest {
}
match.play(move);
}
generateMatchData(m.getPlayer1().getName(), m.getPlayer2().getName(), match, millisecondscounterAI1, millisecondscounterAI2);
generateData(m, match, iterationsAI1, iterationsAI2);
generateMatchData(m.getPlayer1().getName(), m.getPlayer2().getName(), match, iterationsAI1, iterationsAI2, millisecondscounterAI1, millisecondscounterAI2);
}
public void generateMatchData(String AI1, String AI2, BitboardReversi match, long millisecondscounterAI1, long millisecondscounterAI2) {
addGameData(new GameData(
AI1,
AI2,
getWinnerForMatch(AI1, AI2, match),
match.getAmountOfTurns(),
millisecondscounterAI1,
millisecondscounterAI2
));
public void generateMatchData(
String AI1,
String AI2,
BitboardReversi match,
List<Integer> iterationsAI1,
List<Integer> iterationsAI2,
long millisecondscounterAI1,
long millisecondscounterAI2
) {
try {
var ai110 = iterationsAI1.subList(0, 9);
var ai120 = iterationsAI1.subList(10, 19);
var ai130 = iterationsAI1.subList(20, iterationsAI1.size());
var ai210 = iterationsAI2.subList(0, 9);
var ai220 = iterationsAI2.subList(10, 19);
var ai230 = iterationsAI2.subList(20, iterationsAI2.size());
writeGamesToCSV("gameData.csv", new GameData(
AI1,
AI2,
getWinnerForMatch(AI1, AI2, match),
match.getAmountOfTurns(),
iterationsAI1.stream().mapToInt(Integer::intValue).sum(),
iterationsAI1.stream().mapToDouble(Integer::doubleValue).sum() / iterationsAI1.size(),
ai110.stream().mapToDouble(Integer::doubleValue).sum() / ai110.size(),
ai120.stream().mapToDouble(Integer::doubleValue).sum() / ai120.size(),
ai130.stream().mapToDouble(Integer::doubleValue).sum() / ai130.size(),
iterationsAI2.stream().mapToInt(Integer::intValue).sum(),
iterationsAI2.stream().mapToDouble(Integer::doubleValue).sum() / iterationsAI2.size(),
ai210.stream().mapToDouble(Integer::doubleValue).sum() / ai210.size(),
ai220.stream().mapToDouble(Integer::doubleValue).sum() / ai220.size(),
ai230.stream().mapToDouble(Integer::doubleValue).sum() / ai230.size(),
millisecondscounterAI1,
millisecondscounterAI2,
LocalTime.now().toString()
));
} catch (IOException e) {
throw new RuntimeException(e);
} catch (IndexOutOfBoundsException e) {
return;
}
}
public String getWinnerForMatch(String AI1, String AI2, BitboardReversi match) {
@@ -186,26 +238,45 @@ public class AITest {
public static void writeAfterTests() {
try {
writeAIToCsv("Data.csv", dataList);
writeGamesToCSV("Games.csv", gameDataList);
} catch (IOException e) {
e.printStackTrace();
}
}
public static void writeGamesToCSV(String filepath, List<GameData> gameDataList) throws IOException {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filepath))) {
writer.write("Black,White,Winner,Turns Played,Total Time AI1, Total Time AI2");
writer.newLine();
for (GameData gameData : gameDataList) {
writer.write(
gameData.AI1() + "," +
gameData.AI2() + "," +
gameData.winner() + "," +
gameData.turns() + "," +
(gameData.millisecondsAI1() / 1_000_000L) + "," +
(gameData.millisecondsAI2() / 1_000_000L));
public static void writeGamesToCSV(String filepath, GameData gameData) throws IOException {
try (
final BufferedWriter writer = Files.newBufferedWriter(
Paths.get(filepath),
StandardCharsets.UTF_8,
StandardOpenOption.CREATE,
StandardOpenOption.APPEND
);
final BufferedReader reader = new BufferedReader(new FileReader(filepath))
) {
if (reader.readLine() == null || reader.readLine().isBlank()) {
writer.write("Black,White,Winner,Turns Played,Black iterations,Black average iterations,Black average iterations 0-10,Black average iterations 11-20,Black average iterations 21-30,White iterations,White average iterations,White average iterations 0-10,White average iterations 11-20,White average iterations 21-30,Total Time AI1,Total Time AI2,Time");
writer.newLine();
}
writer.write(
gameData.AI1() + "," +
gameData.AI2() + "," +
gameData.winner() + "," +
gameData.turns() + "," +
gameData.AI1totalIterations() + "," +
BigDecimal.valueOf(gameData.AI1averageIterations()).setScale(2, RoundingMode.DOWN) + "," +
BigDecimal.valueOf(gameData.AI1averageIterations10()).setScale(2, RoundingMode.DOWN) + "," +
BigDecimal.valueOf(gameData.AI1averageIterations20()).setScale(2, RoundingMode.DOWN) + "," +
BigDecimal.valueOf(gameData.AI1averageIterations30()).setScale(2, RoundingMode.DOWN) + "," +
gameData.AI2totalIterations() + "," +
BigDecimal.valueOf(gameData.AI2averageIterations()).setScale(2, RoundingMode.DOWN) + "," +
BigDecimal.valueOf(gameData.AI2averageIterations10()).setScale(2, RoundingMode.DOWN) + "," +
BigDecimal.valueOf(gameData.AI2averageIterations20()).setScale(2, RoundingMode.DOWN) + "," +
BigDecimal.valueOf(gameData.AI2averageIterations30()).setScale(2, RoundingMode.DOWN) + "," +
(gameData.millisecondsAI1() / 1_000_000L) + "," +
(gameData.millisecondsAI2() / 1_000_000L) + "," +
gameData.time());
writer.newLine();
}
}
@@ -216,12 +287,12 @@ public class AITest {
for (AIData data : dataList) {
writer.write(
data.getAI() + "," +
data.getGamesPlayed() + "," +
data.getWinrate() + "," +
Math.round(data.getAverageIterations()) + "," +
Math.round(data.getAverageIterations10()) + "," +
Math.round(data.getAverageIterations20()) + "," +
Math.round(data.getAverageIterations30()));
data.getGamesPlayed() + "," +
data.getWinrate() + "," +
Math.round(data.getAverageIterations()) + "," +
Math.round(data.getAverageIterations10()) + "," +
Math.round(data.getAverageIterations20()) + "," +
Math.round(data.getAverageIterations30()));
writer.newLine();
}
}

View File

@@ -1,4 +1,26 @@
package research;
public record GameData(String AI1, String AI2, String winner, int turns, long millisecondsAI1, long millisecondsAI2) {}
public record GameData(
String AI1,
String AI2,
String winner,
int turns,
int AI1totalIterations,
double AI1averageIterations,
double AI1averageIterations10,
double AI1averageIterations20,
double AI1averageIterations30,
int AI2totalIterations,
double AI2averageIterations,
double AI2averageIterations10,
double AI2averageIterations20,
double AI2averageIterations30,
long millisecondsAI1,
long millisecondsAI2,
String time
) {}