diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java index bee092c..a39594a 100644 --- a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java @@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.game.players.ai.MCTSAI; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -11,22 +12,18 @@ public class MCTSAI3 extends MCTSAI { private final int threads; private final ExecutorService threadPool; - public MCTSAI3(int milliseconds) { - threads = 8; - threadPool = Executors.newFixedThreadPool(8); - super(milliseconds); - } - public MCTSAI3(int milliseconds, int threads) { - this.threads = threads; - threadPool = Executors.newFixedThreadPool(threads); super(milliseconds); + + this.threads = threads; + this.threadPool = Executors.newFixedThreadPool(threads); } public MCTSAI3(MCTSAI3 other) { - threads = 8; - threadPool = Executors.newFixedThreadPool(8); super(other); + + this.threads = other.threads; + this.threadPool = other.threadPool; } @Override @@ -40,12 +37,21 @@ public class MCTSAI3 extends MCTSAI { final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + final CountDownLatch latch = new CountDownLatch(threads); + for (int i = 0; i < threads; i++) { - threadPool.submit(() -> iterate(root, endTime)); + threadPool.submit(() -> { + try { + iterate(root, endTime); + } finally { + latch.countDown(); + } + }); } try { - threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS); + final long remaining = endTime - System.nanoTime(); + latch.await(remaining, TimeUnit.NANOSECONDS); lastIterations = root.visits.get(); @@ -59,14 +65,12 @@ public class MCTSAI3 extends MCTSAI { } } - private Void iterate(Node root, long endTime) { + private void iterate(Node root, long endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { Node leaf = selection(root); leaf = expansion(leaf); final int value = simulation(leaf); backPropagation(leaf, value); } - - return null; } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java index 0ec748a..630df44 100644 --- a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java @@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.game.players.ai.MCTSAI; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -13,25 +14,18 @@ public class MCTSAI4 extends MCTSAI { private Node root; - public MCTSAI4(int milliseconds) { - threads = 8; - threadPool = Executors.newFixedThreadPool(8); - super(milliseconds); - this.root = null; - } - public MCTSAI4(int milliseconds, int threads) { - this.threads = threads; - threadPool = Executors.newFixedThreadPool(threads); super(milliseconds); - this.root = null; + + this.threads = threads; + this.threadPool = Executors.newFixedThreadPool(threads); } public MCTSAI4(MCTSAI4 other) { - threads = 8; - threadPool = Executors.newFixedThreadPool(8); super(other); - this.root = other.root; + + this.threads = other.threads; + this.threadPool = other.threadPool; } @Override @@ -45,12 +39,21 @@ public class MCTSAI4 extends MCTSAI { final long endTime = System.nanoTime() + milliseconds * 1_000_000L; + final CountDownLatch latch = new CountDownLatch(threads); + for (int i = 0; i < threads; i++) { - threadPool.submit(() -> iterate(root, endTime)); + threadPool.submit(() -> { + try { + iterate(root, endTime); + } finally { + latch.countDown(); + } + }); } try { - threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS); + final long remaining = endTime - System.nanoTime(); + latch.await(remaining, TimeUnit.NANOSECONDS); lastIterations = root.visits.get(); @@ -68,14 +71,12 @@ public class MCTSAI4 extends MCTSAI { } } - private Void iterate(Node root, long endTime) { + private void iterate(Node root, long endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { Node leaf = selection(root); leaf = expansion(leaf); final int value = simulation(leaf); backPropagation(leaf, value); } - - return null; } } \ No newline at end of file