fixed extra wait time for threads

This commit is contained in:
ramollia
2026-01-23 19:09:28 +01:00
parent b5bd4adf91
commit 039c0393c8
2 changed files with 26 additions and 10 deletions

View File

@@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@@ -31,12 +32,21 @@ public class MCTSAI3 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(THREADS);
for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
threadPool.submit(() -> {
try {
iterate(root, endTime);
} finally {
latch.countDown();
}
});
}
try {
threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS);
final long remaining = endTime - System.nanoTime();
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get();
@@ -50,14 +60,12 @@ public class MCTSAI3 extends MCTSAI {
}
}
private Void iterate(Node root, long endTime) {
private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
return null;
}
}

View File

@@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@@ -37,12 +38,21 @@ public class MCTSAI4 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(THREADS);
for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
threadPool.submit(() -> {
try {
iterate(root, endTime);
} finally {
latch.countDown();
}
});
}
try {
threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS);
final long remaining = endTime - System.nanoTime();
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get();
@@ -60,14 +70,12 @@ public class MCTSAI4 extends MCTSAI {
}
}
private Void iterate(Node root, long endTime) {
private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
return null;
}
}