Changed how multithreading works: worker threads now share a single MCTS tree

This commit is contained in:
ramollia
2026-01-21 20:05:19 +01:00
parent f168b974ab
commit 5d2fff7ae7
9 changed files with 106 additions and 163 deletions

View File

@@ -4,9 +4,12 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class MCTSAI extends AbstractAI {
protected static class Node {
public static final int VIRTUAL_LOSS = -1;
public TurnBasedGame state;
public long move;
@@ -15,8 +18,8 @@ public abstract class MCTSAI extends AbstractAI {
public Node parent;
public Node[] children;
public float value;
public int visits;
public AtomicInteger value;
public AtomicInteger visits;
public float heuristic;
@@ -33,8 +36,8 @@ public abstract class MCTSAI extends AbstractAI {
this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)];
this.value = 0.0f;
this.visits = 0;
this.value = new AtomicInteger(0);
this.visits = new AtomicInteger(0);
this.heuristic = state.rateMove(move);
@@ -53,14 +56,14 @@ public abstract class MCTSAI extends AbstractAI {
return unexpandedMoves == 0L;
}
public float calculateUCT(int parentVisits) {
if (visits == 0) {
public float calculateUCT(float explorationFactor) {
if (visits.get() == 0) {
return Float.POSITIVE_INFINITY;
}
final float exploitation = value / visits;
final float exploration = 1.4141f * (float)(Math.sqrt(Math.log(parentVisits) / visits));
final float bias = heuristic * 10.0f / (visits + 1);
final float exploitation = (float) value.get() / visits.get();
final float exploration = (float)(Math.sqrt(explorationFactor / visits.get()));
final float bias = heuristic * 10.0f / (visits.get() + 1);
return exploitation + exploration + bias;
}
@@ -72,7 +75,7 @@ public abstract class MCTSAI extends AbstractAI {
float highestUCT = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) {
final float childUCT = children[i].calculateUCT(visits);
final float childUCT = children[i].calculateUCT(2.0f * (float)Math.log(visits.get()));
if (childUCT > highestUCT) {
highestUCTChild = children[i];
@@ -108,31 +111,39 @@ public abstract class MCTSAI extends AbstractAI {
/**
 * Descends from {@code root} by repeatedly moving to {@code bestUCTChild()} and
 * returns the node at which the descent stops.
 *
 * <p>The descent continues only while the current node's {@code solved} field is
 * NaN, the node is fully expanded, and its state is not terminal. Every node
 * visited on the way, including the returned one, has {@code Node.VIRTUAL_LOSS}
 * added to its value and its visit counter incremented.
 *
 * @param root node to start the descent from
 * @return the first node on the path that fails the descent condition
 */
protected Node selection(Node root) {
    for (Node current = root; ; current = current.bestUCTChild()) {
        // Decide whether to keep descending before touching the counters,
        // so the check and the updates happen in a fixed order.
        final boolean canDescend = Float.isNaN(current.solved)
                && current.isFullyExpanded()
                && !current.state.isTerminal();

        current.value.addAndGet(Node.VIRTUAL_LOSS);
        current.visits.incrementAndGet();

        if (!canDescend) {
            return current;
        }
    }
}
protected Node expansion(Node leaf) {
if (leaf.unexpandedMoves == 0L) {
return leaf;
synchronized (leaf) {
if (leaf.unexpandedMoves == 0L) {
return leaf;
}
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.getExpanded()] = expandedChild;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
}
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.getExpanded()] = expandedChild;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
}
protected float simulation(Node leaf) {
protected int simulation(Node leaf) {
final TurnBasedGame copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn();
@@ -144,20 +155,20 @@ public abstract class MCTSAI extends AbstractAI {
}
if (copiedState.getWinner() == playerIndex) {
return 1.0f;
return 1;
}
if (copiedState.getWinner() >= 0) {
return -1.0f;
return -1;
}
return 0.0f;
return 0;
}
protected void backPropagation(Node leaf, float value) {
protected void backPropagation(Node leaf, int value) {
while (leaf != null) {
leaf.value += value;
leaf.visits++;
value -= Node.VIRTUAL_LOSS;
leaf.value.addAndGet(value);
if (Float.isNaN(leaf.solved)) {
updateSolvedStatus(leaf);
@@ -175,9 +186,9 @@ public abstract class MCTSAI extends AbstractAI {
int mostVisited = -1;
for (int i = 0; i < expanded; i++) {
if (root.children[i].visits > mostVisited) {
if (root.children[i].visits.get() > mostVisited) {
mostVisitedChild = root.children[i];
mostVisited = root.children[i].visits;
mostVisited = root.children[i].visits.get();
}
}

View File

@@ -26,12 +26,11 @@ public class MCTSAI1 extends MCTSAI {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final float value = simulation(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
lastIterations = root.visits;
IO.println("V1: " + lastIterations);
lastIterations = root.visits.get();
final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move;

View File

@@ -32,12 +32,11 @@ public class MCTSAI2 extends MCTSAI {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final float value = simulation(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
lastIterations = root.visits;
IO.println("V2: " + lastIterations);
lastIterations = root.visits.get();
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;

View File

@@ -3,26 +3,21 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
public class MCTSAI3 extends MCTSAI {
private final int threads;
private static final int THREADS = 8;
public MCTSAI3(int milliseconds, int threads) {
private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS);
public MCTSAI3(int milliseconds) {
super(milliseconds);
this.threads = threads;
}
public MCTSAI3(MCTSAI3 other) {
super(other);
this.threads = other.threads;
}
@Override
@@ -32,53 +27,18 @@ public class MCTSAI3 extends MCTSAI {
@Override
public long getMove(TurnBasedGame game) {
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final Node root = new Node(game.deepCopy(), null, 0L);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
tasks.add(() -> {
final Node localRoot = new Node(game.deepCopy());
while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
IO.println("V3: " + lastIterations);
lastIterations = root.visits.get();
final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move;
@@ -89,4 +49,15 @@ public class MCTSAI3 extends MCTSAI {
return randomSetBit(legalMoves);
}
}
/**
 * Repeats the selection, expansion, simulation and back-propagation steps on the
 * tree under {@code root}.
 *
 * <p>The loop stops as soon as {@code root.solved} is no longer NaN or
 * {@code System.nanoTime()} has reached {@code endTime}.
 *
 * @param root    root of the tree to iterate on
 * @param endTime deadline, compared against {@code System.nanoTime()}
 * @return always {@code null}
 */
private Void iterate(Node root, long endTime) {
    for (;;) {
        final boolean finished = !Float.isNaN(root.solved) || System.nanoTime() >= endTime;
        if (finished) {
            return null;
        }

        final Node selected = selection(root);
        final Node expanded = expansion(selected);
        final int outcome = simulation(expanded);
        backPropagation(expanded, outcome);
    }
}
}

View File

@@ -3,29 +3,27 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
public class MCTSAI4 extends MCTSAI {
private final int threads;
private final Node[] threadRoots;
private static final int THREADS = Runtime.getRuntime().availableProcessors();
public MCTSAI4(int milliseconds, int threads) {
private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS);
private Node root;
public MCTSAI4(int milliseconds) {
super(milliseconds);
this.threads = threads;
this.threadRoots = new Node[threads];
this.root = null;
}
public MCTSAI4(MCTSAI4 other) {
super(other);
this.threads = other.threads;
this.threadRoots = other.threadRoots;
this.root = other.root;
}
@Override
@@ -35,66 +33,23 @@ public class MCTSAI4 extends MCTSAI {
@Override
public long getMove(TurnBasedGame game) {
for (int i = 0; i < threads; i++) {
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
root = findOrResetRoot(root, game);
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>();
for (int i = 0; i < threads; i++) {
final int threadIndex = i;
tasks.add(() -> {
final Node localRoot = threadRoots[threadIndex];
while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
}
try {
final List<Future<Node>> results = pool.invokeAll(tasks);
threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS);
pool.shutdown();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
IO.println("V4: " + lastIterations);
lastIterations = root.visits.get();
final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move;
for (int i = 0; i < threads; i++) {
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
root = findChildByMove(root, move);
return move;
} catch (Exception _) {
@@ -104,4 +59,15 @@ public class MCTSAI4 extends MCTSAI {
return randomSetBit(legalMoves);
}
}
/**
 * Runs search iterations on the tree under {@code root} until either
 * {@code root.solved} stops being NaN or the deadline is reached.
 *
 * <p>Each iteration selects a node, expands it, simulates from the expanded node
 * and back-propagates the simulation result.
 *
 * @param root    root of the tree to iterate on
 * @param endTime deadline, compared against {@code System.nanoTime()}
 * @return always {@code null}
 */
private Void iterate(Node root, long endTime) {
    boolean keepGoing = Float.isNaN(root.solved) && System.nanoTime() < endTime;
    while (keepGoing) {
        final Node expanded = expansion(selection(root));
        backPropagation(expanded, simulation(expanded));

        // Re-evaluate the stop condition only after a full iteration has finished.
        keepGoing = Float.isNaN(root.solved) && System.nanoTime() < endTime;
    }
    return null;
}
}

View File

@@ -25,11 +25,11 @@ public class AITest {
@BeforeAll
public static void init() {
var versions = new ArtificialPlayer[2];
// versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
// versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
versions[0] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3");
versions[1] = new ArtificialPlayer(new MCTSAI4(10, 8), "MCTS V4");
var versions = new ArtificialPlayer[4];
versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
versions[2] = new ArtificialPlayer(new MCTSAI3(10), "MCTS V3");
versions[3] = new ArtificialPlayer(new MCTSAI4(10), "MCTS V4");
for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length;
@@ -96,7 +96,8 @@ public class AITest {
addGameData(new GameData(
AI1,
AI2,
getWinnerForMatch(AI1, AI2, match
getWinnerForMatch(AI1, AI2, match),
match.getAmountOfTurns(),
millisecondscounterAI1,
millisecondscounterAI2
));