6 Commits

Author SHA1 Message Date
lieght
22a73fc50a Moved back in threads 2026-01-23 19:16:29 +01:00
lieght
af9a316639 Merge remote-tracking branch 'origin/Development' into Development 2026-01-23 19:15:55 +01:00
lieght
8508377cb4 Parameters added to tests 2026-01-23 19:15:29 +01:00
ramollia
97276c7e80 readded threads argument 2026-01-23 19:14:46 +01:00
ramollia
11eda3c8b5 Merge remote-tracking branch 'origin/Development' into Development
# Conflicts:
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java
#	game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java
2026-01-23 19:09:51 +01:00
ramollia
039c0393c8 fixed extra wait time for threads 2026-01-23 19:09:28 +01:00
3 changed files with 70 additions and 63 deletions

View File

@@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@@ -11,22 +12,18 @@ public class MCTSAI3 extends MCTSAI {
private final int threads; private final int threads;
private final ExecutorService threadPool; private final ExecutorService threadPool;
public MCTSAI3(int milliseconds) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(milliseconds);
}
public MCTSAI3(int milliseconds, int threads) { public MCTSAI3(int milliseconds, int threads) {
this.threads = threads;
threadPool = Executors.newFixedThreadPool(threads);
super(milliseconds); super(milliseconds);
this.threads = threads;
this.threadPool = Executors.newFixedThreadPool(threads);
} }
public MCTSAI3(MCTSAI3 other) { public MCTSAI3(MCTSAI3 other) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(other); super(other);
this.threads = other.threads;
this.threadPool = other.threadPool;
} }
@Override @Override
@@ -40,12 +37,21 @@ public class MCTSAI3 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(threads);
for (int i = 0; i < threads; i++) { for (int i = 0; i < threads; i++) {
threadPool.submit(() -> iterate(root, endTime)); threadPool.submit(() -> {
try {
iterate(root, endTime);
} finally {
latch.countDown();
}
});
} }
try { try {
threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS); final long remaining = endTime - System.nanoTime();
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get(); lastIterations = root.visits.get();
@@ -59,14 +65,12 @@ public class MCTSAI3 extends MCTSAI {
} }
} }
private Void iterate(Node root, long endTime) { private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final int value = simulation(leaf); final int value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
return null;
} }
} }

View File

@@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@@ -13,25 +14,18 @@ public class MCTSAI4 extends MCTSAI {
private Node root; private Node root;
public MCTSAI4(int milliseconds) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(milliseconds);
this.root = null;
}
public MCTSAI4(int milliseconds, int threads) { public MCTSAI4(int milliseconds, int threads) {
this.threads = threads;
threadPool = Executors.newFixedThreadPool(threads);
super(milliseconds); super(milliseconds);
this.root = null;
this.threads = threads;
this.threadPool = Executors.newFixedThreadPool(threads);
} }
public MCTSAI4(MCTSAI4 other) { public MCTSAI4(MCTSAI4 other) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(other); super(other);
this.root = other.root;
this.threads = other.threads;
this.threadPool = other.threadPool;
} }
@Override @Override
@@ -45,12 +39,21 @@ public class MCTSAI4 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(threads);
for (int i = 0; i < threads; i++) { for (int i = 0; i < threads; i++) {
threadPool.submit(() -> iterate(root, endTime)); threadPool.submit(() -> {
try {
iterate(root, endTime);
} finally {
latch.countDown();
}
});
} }
try { try {
threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS); final long remaining = endTime - System.nanoTime();
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get(); lastIterations = root.visits.get();
@@ -68,14 +71,12 @@ public class MCTSAI4 extends MCTSAI {
} }
} }
private Void iterate(Node root, long endTime) { private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final int value = simulation(leaf); final int value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
return null;
} }
} }

View File

@@ -30,44 +30,20 @@ import java.util.List;
public class AITest { public class AITest {
private static String fileName = "gameDataThreads.csv"; private static String fileName = "gameData.csv";
private static List<Matchup> matchupList = new ArrayList<Matchup>(); private static List<Matchup> matchupList = new ArrayList<Matchup>();
private static List<AIData> dataList = new ArrayList<AIData>(); private static List<AIData> dataList = new ArrayList<AIData>();
private static List<GameData> gameDataList = new ArrayList<GameData>(); private static List<GameData> gameDataList = new ArrayList<GameData>();
// @BeforeAll
// public static void init() {
//
// var versions = new ArtificialPlayer[4];
// versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
// versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
// versions[2] = new ArtificialPlayer(new MCTSAI3(10), "MCTS V3");
// versions[3] = new ArtificialPlayer(new MCTSAI4(10), "MCTS V4");
//
// for (int i = 0; i < versions.length; i++) {
// for (int j = i + 1; j < versions.length; j++) {
// final int playerIndex1 = i % versions.length;
// final int playerIndex2 = j % versions.length;
// addMatch(versions[playerIndex1], versions[playerIndex2]);
// addMatch(versions[playerIndex2], versions[playerIndex1]); // home vs away system
// }
// }
// }
@BeforeAll @BeforeAll
public static void init() { public static void init() {
var versions = new ArtificialPlayer[9]; var versions = new ArtificialPlayer[4];
versions[0] = new ArtificialPlayer(new MCTSAI3(10, 1), "MCTS V3T1"); versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
versions[1] = new ArtificialPlayer(new MCTSAI3(10, 2), "MCTS V3T2"); versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
versions[2] = new ArtificialPlayer(new MCTSAI3(10, 4), "MCTS V3T4"); versions[2] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3");
versions[3] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3T8"); versions[3] = new ArtificialPlayer(new MCTSAI4(10, 8), "MCTS V4");
versions[4] = new ArtificialPlayer(new MCTSAI3(10, 16), "MCTS V3T16");
versions[5] = new ArtificialPlayer(new MCTSAI3(10, 128), "MCTS V3T128");
versions[6] = new ArtificialPlayer(new MCTSAI3(10, 256), "MCTS V3T256");
versions[7] = new ArtificialPlayer(new MCTSAI3(10, 512), "MCTS V3T512");
versions[8] = new ArtificialPlayer(new MCTSAI3(10, 1024), "MCTS V3T1024");
for (int i = 0; i < versions.length; i++) { for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) { for (int j = i + 1; j < versions.length; j++) {
@@ -79,6 +55,32 @@ public class AITest {
} }
} }
// @BeforeAll
// public static void init() {
//
// var versions = new ArtificialPlayer[11];
// versions[0] = new ArtificialPlayer(new MCTSAI3(10, 1), "MCTS V3T1");
// versions[1] = new ArtificialPlayer(new MCTSAI3(10, 2), "MCTS V3T2");
// versions[2] = new ArtificialPlayer(new MCTSAI3(10, 4), "MCTS V3T4");
// versions[3] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3T8");
// versions[4] = new ArtificialPlayer(new MCTSAI3(10, 16), "MCTS V3T16");
// versions[5] = new ArtificialPlayer(new MCTSAI3(10, 128), "MCTS V3T32");
// versions[6] = new ArtificialPlayer(new MCTSAI3(10, 256), "MCTS V3T64");
// versions[7] = new ArtificialPlayer(new MCTSAI3(10, 128), "MCTS V3T128");
// versions[8] = new ArtificialPlayer(new MCTSAI3(10, 256), "MCTS V3T256");
// versions[9] = new ArtificialPlayer(new MCTSAI3(10, 512), "MCTS V3T512");
// versions[10] = new ArtificialPlayer(new MCTSAI3(10, 1024), "MCTS V3T1024");
//
// for (int i = 0; i < versions.length; i++) {
// for (int j = i + 1; j < versions.length; j++) {
// final int playerIndex1 = i % versions.length;
// final int playerIndex2 = j % versions.length;
// addMatch(versions[playerIndex1], versions[playerIndex2]);
// addMatch(versions[playerIndex2], versions[playerIndex1]); // home vs away system
// }
// }
// }
public static void addMatch(ArtificialPlayer v1, ArtificialPlayer v2) { public static void addMatch(ArtificialPlayer v1, ArtificialPlayer v2) {
matchupList.add(new Matchup(v1, v2)); matchupList.add(new Matchup(v1, v2));
} }