Merge remote-tracking branch 'origin/Development' into Development

This commit is contained in:
lieght
2026-01-23 19:15:55 +01:00
2 changed files with 38 additions and 33 deletions

View File

@@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@@ -11,22 +12,18 @@ public class MCTSAI3 extends MCTSAI {
private final int threads; private final int threads;
private final ExecutorService threadPool; private final ExecutorService threadPool;
public MCTSAI3(int milliseconds) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(milliseconds);
}
public MCTSAI3(int milliseconds, int threads) { public MCTSAI3(int milliseconds, int threads) {
this.threads = threads;
threadPool = Executors.newFixedThreadPool(threads);
super(milliseconds); super(milliseconds);
this.threads = threads;
this.threadPool = Executors.newFixedThreadPool(threads);
} }
public MCTSAI3(MCTSAI3 other) { public MCTSAI3(MCTSAI3 other) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(other); super(other);
this.threads = other.threads;
this.threadPool = other.threadPool;
} }
@Override @Override
@@ -40,12 +37,21 @@ public class MCTSAI3 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(threads);
for (int i = 0; i < threads; i++) { for (int i = 0; i < threads; i++) {
threadPool.submit(() -> iterate(root, endTime)); threadPool.submit(() -> {
try {
iterate(root, endTime);
} finally {
latch.countDown();
}
});
} }
try { try {
threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS); final long remaining = endTime - System.nanoTime();
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get(); lastIterations = root.visits.get();
@@ -59,14 +65,12 @@ public class MCTSAI3 extends MCTSAI {
} }
} }
private Void iterate(Node root, long endTime) { private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final int value = simulation(leaf); final int value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
return null;
} }
} }

View File

@@ -3,6 +3,7 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@@ -13,25 +14,18 @@ public class MCTSAI4 extends MCTSAI {
private Node root; private Node root;
public MCTSAI4(int milliseconds) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(milliseconds);
this.root = null;
}
public MCTSAI4(int milliseconds, int threads) { public MCTSAI4(int milliseconds, int threads) {
this.threads = threads;
threadPool = Executors.newFixedThreadPool(threads);
super(milliseconds); super(milliseconds);
this.root = null;
this.threads = threads;
this.threadPool = Executors.newFixedThreadPool(threads);
} }
public MCTSAI4(MCTSAI4 other) { public MCTSAI4(MCTSAI4 other) {
threads = 8;
threadPool = Executors.newFixedThreadPool(8);
super(other); super(other);
this.root = other.root;
this.threads = other.threads;
this.threadPool = other.threadPool;
} }
@Override @Override
@@ -45,12 +39,21 @@ public class MCTSAI4 extends MCTSAI {
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final CountDownLatch latch = new CountDownLatch(threads);
for (int i = 0; i < threads; i++) { for (int i = 0; i < threads; i++) {
threadPool.submit(() -> iterate(root, endTime)); threadPool.submit(() -> {
try {
iterate(root, endTime);
} finally {
latch.countDown();
}
});
} }
try { try {
threadPool.awaitTermination(milliseconds, TimeUnit.MILLISECONDS); final long remaining = endTime - System.nanoTime();
latch.await(remaining, TimeUnit.NANOSECONDS);
lastIterations = root.visits.get(); lastIterations = root.visits.get();
@@ -68,14 +71,12 @@ public class MCTSAI4 extends MCTSAI {
} }
} }
private Void iterate(Node root, long endTime) { private void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final int value = simulation(leaf); final int value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
return null;
} }
} }