From 5d2fff7ae7c91c5dae5ddf6f44227f6ee9713e8b Mon Sep 17 00:00:00 2001 From: ramollia <> Date: Wed, 21 Jan 2026 20:05:19 +0100 Subject: [PATCH] changed the way multithreading worked --- app/src/main/java/org/toop/app/Server.java | 2 +- .../app/widget/view/LocalMultiplayerView.java | 6 +- .../game/games/reversi/BitboardReversi.java | 2 - .../java/org/toop/game/players/ai/MCTSAI.java | 77 ++++++++++------- .../toop/game/players/ai/mcts/MCTSAI1.java | 5 +- .../toop/game/players/ai/mcts/MCTSAI2.java | 5 +- .../toop/game/players/ai/mcts/MCTSAI3.java | 73 +++++----------- .../toop/game/players/ai/mcts/MCTSAI4.java | 86 ++++++------------- game/src/test/java/research/AITest.java | 13 +-- 9 files changed, 106 insertions(+), 163 deletions(-) diff --git a/app/src/main/java/org/toop/app/Server.java b/app/src/main/java/org/toop/app/Server.java index 06a5b75..75cc0f5 100644 --- a/app/src/main/java/org/toop/app/Server.java +++ b/app/src/main/java/org/toop/app/Server.java @@ -210,7 +210,7 @@ public final class Server { Player[] players = new Player[2]; - players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000, Runtime.getRuntime().availableProcessors()), user); + players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000), user); players[opponentStartingTurn] = new OnlinePlayer(response.opponent()); switch (type) { diff --git a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java index c7cf20e..e44875b 100644 --- a/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java +++ b/app/src/main/java/org/toop/app/widget/view/LocalMultiplayerView.java @@ -14,11 +14,9 @@ import org.toop.app.widget.complex.ViewWidget; import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.tutorial.*; import org.toop.game.players.ai.MiniMaxAI; -import org.toop.game.players.ai.RandomAI; import org.toop.game.players.ai.mcts.MCTSAI1; import org.toop.game.players.ai.mcts.MCTSAI3; import org.toop.game.players.ai.mcts.MCTSAI4; -import org.toop.game.players.ai.mcts.OPMCTSAI; import org.toop.local.AppContext; import javafx.geometry.Pos; @@ -85,12 +83,12 @@ public class LocalMultiplayerView extends ViewWidget { players[0] = new LocalPlayer(information.players[0].name); } else { // players[0] = new ArtificialPlayer(new RandomAI(), "Random AI"); - players[0] = new ArtificialPlayer(new MCTSAI4(500, 4), "MCTS V4 AI"); + players[0] = new ArtificialPlayer(new MCTSAI1(100), "MCTS V1 AI"); } if (information.players[1].isHuman) { players[1] = new LocalPlayer(information.players[1].name); } else { - players[1] = new ArtificialPlayer(new MCTSAI1(500), "MCTS V1 AI"); + players[1] = new ArtificialPlayer(new MCTSAI4(100), "MCTS V4 AI"); } if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { new ShowEnableTutorialWidget( diff --git a/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java b/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java index c9429bc..ba5e8c7 100644 --- a/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java +++ b/framework/src/main/java/org/toop/framework/game/games/reversi/BitboardReversi.java @@ -3,9 +3,7 @@ package org.toop.framework.game.games.reversi; import org.toop.framework.game.BitboardGame; import org.toop.framework.gameFramework.GameState; import org.toop.framework.gameFramework.model.game.PlayResult; -import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.Player; -import org.toop.framework.game.BitboardGame; public class BitboardReversi extends BitboardGame { diff --git a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java index 8a95505..9ac0d1c 100644 --- a/game/src/main/java/org/toop/game/players/ai/MCTSAI.java +++ b/game/src/main/java/org/toop/game/players/ai/MCTSAI.java @@ -4,9 +4,12 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.player.AbstractAI; import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; public abstract class MCTSAI extends AbstractAI { protected static class Node { + public static final int VIRTUAL_LOSS = -1; + public TurnBasedGame state; public long move; @@ -15,8 +18,8 @@ public abstract class MCTSAI extends AbstractAI { public Node parent; public Node[] children; - public float value; - public int visits; + public AtomicInteger value; + public AtomicInteger visits; public float heuristic; @@ -33,8 +36,8 @@ public abstract class MCTSAI extends AbstractAI { this.parent = parent; this.children = new Node[Long.bitCount(legalMoves)]; - this.value = 0.0f; - this.visits = 0; + this.value = new AtomicInteger(0); + this.visits = new AtomicInteger(0); this.heuristic = state.rateMove(move); @@ -53,14 +56,14 @@ public abstract class MCTSAI extends AbstractAI { return unexpandedMoves == 0L; } - public float calculateUCT(int parentVisits) { - if (visits == 0) { + public float calculateUCT(float explorationFactor) { + if (visits.get() == 0) { return Float.POSITIVE_INFINITY; } - final float exploitation = value / visits; - final float exploration = 1.4141f * (float)(Math.sqrt(Math.log(parentVisits) / visits)); - final float bias = heuristic * 10.0f / (visits + 1); + final float exploitation = (float) value.get() / visits.get(); + final float exploration = (float)(Math.sqrt(explorationFactor / visits.get())); + final float bias = heuristic * 10.0f / (visits.get() + 1); return exploitation + exploration + bias; } @@ -72,7 +75,7 @@ public abstract class MCTSAI extends AbstractAI { float highestUCT = Float.NEGATIVE_INFINITY; for (int i = 0; i < expanded; i++) { - final float childUCT = children[i].calculateUCT(visits); + final float childUCT = children[i].calculateUCT(2.0f * (float)Math.log(visits.get())); if (childUCT > highestUCT) { highestUCTChild = children[i]; @@ -108,31 +111,39 @@ public abstract class MCTSAI extends AbstractAI { protected Node selection(Node root) { while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) { + root.value.addAndGet(Node.VIRTUAL_LOSS); + root.visits.incrementAndGet(); + root = root.bestUCTChild(); } + root.value.addAndGet(Node.VIRTUAL_LOSS); + root.visits.incrementAndGet(); + return root; } protected Node expansion(Node leaf) { - if (leaf.unexpandedMoves == 0L) { - return leaf; + synchronized (leaf) { + if (leaf.unexpandedMoves == 0L) { + return leaf; + } + + final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; + + final TurnBasedGame copiedState = leaf.state.deepCopy(); + copiedState.play(unexpandedMove); + + final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); + + leaf.children[leaf.getExpanded()] = expandedChild; + leaf.unexpandedMoves &= ~unexpandedMove; + + return expandedChild; } - - final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves; - - final TurnBasedGame copiedState = leaf.state.deepCopy(); - copiedState.play(unexpandedMove); - - final Node expandedChild = new Node(copiedState, leaf, unexpandedMove); - - leaf.children[leaf.getExpanded()] = expandedChild; - leaf.unexpandedMoves &= ~unexpandedMove; - - return expandedChild; } - protected float simulation(Node leaf) { + protected int simulation(Node leaf) { final TurnBasedGame copiedState = leaf.state.deepCopy(); final int playerIndex = 1 - copiedState.getCurrentTurn(); @@ -144,20 +155,20 @@ public abstract class MCTSAI extends AbstractAI { } if (copiedState.getWinner() == playerIndex) { - return 1.0f; + return 1; } if (copiedState.getWinner() >= 0) { - return -1.0f; + return -1; } - return 0.0f; + return 0; } - protected void backPropagation(Node leaf, float value) { + protected void backPropagation(Node leaf, int value) { while (leaf != null) { - leaf.value += value; - leaf.visits++; + value -= Node.VIRTUAL_LOSS; + leaf.value.addAndGet(value); if (Float.isNaN(leaf.solved)) { updateSolvedStatus(leaf); @@ -175,9 +186,9 @@ public abstract class MCTSAI extends AbstractAI { int mostVisited = -1; for (int i = 0; i < expanded; i++) { - if (root.children[i].visits > mostVisited) { + if (root.children[i].visits.get() > mostVisited) { mostVisitedChild = root.children[i]; - mostVisited = root.children[i].visits; + mostVisited = root.children[i].visits.get(); } } diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java index 0a67b9d..b12f959 100644 --- a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI1.java @@ -26,12 +26,11 @@ public class MCTSAI1 extends MCTSAI { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { Node leaf = selection(root); leaf = expansion(leaf); - final float value = simulation(leaf); + final int value = simulation(leaf); backPropagation(leaf, value); } - lastIterations = root.visits; - IO.println("V1: " + lastIterations); + lastIterations = root.visits.get(); final Node mostVisitedChild = mostVisitedChild(root); return mostVisitedChild.move; diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java index c0549cd..9ce20de 100644 --- a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI2.java @@ -32,12 +32,11 @@ public class MCTSAI2 extends MCTSAI { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { Node leaf = selection(root); leaf = expansion(leaf); - final float value = simulation(leaf); + final int value = simulation(leaf); backPropagation(leaf, value); } - lastIterations = root.visits; - IO.println("V2: " + lastIterations); + lastIterations = root.visits.get(); final Node mostVisitedChild = mostVisitedChild(root); final long move = mostVisitedChild.move; diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java index bb1682e..82a802b 100644 --- a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI3.java @@ -3,26 +3,21 @@ package org.toop.game.players.ai.mcts; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.game.players.ai.MCTSAI; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; public class MCTSAI3 extends MCTSAI { - private final int threads; + private static final int THREADS = 8; - public MCTSAI3(int milliseconds, int threads) { + private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS); + + public MCTSAI3(int milliseconds) { super(milliseconds); - - this.threads = threads; } public MCTSAI3(MCTSAI3 other) { super(other); - - this.threads = other.threads; } @Override @@ -32,53 +27,18 @@ public class MCTSAI3 extends MCTSAI { @Override public long getMove(TurnBasedGame game) { - final ExecutorService pool = Executors.newFixedThreadPool(threads); + final Node root = new Node(game.deepCopy(), null, 0L); + final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - final List> tasks = new ArrayList<>(); - - for (int i = 0; i < threads; i++) { - tasks.add(() -> { - final Node localRoot = new Node(game.deepCopy()); - - while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) { - Node leaf = selection(localRoot); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - return localRoot; - }); + for (int i = 0; i < THREADS; i++) { + threadPool.submit(() -> iterate(root, endTime)); } try { - final List> results = pool.invokeAll(tasks); + threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS); - pool.shutdown(); - - final Node root = new Node(game.deepCopy()); - - for (int i = 0; i < root.children.length; i++) { - expansion(root); - } - - for (final Future result : results) { - final Node localRoot = result.get(); - - for (final Node localChild : localRoot.children) { - for (int i = 0; i < root.children.length; i++) { - if (localChild.move == root.children[i].move) { - root.children[i].visits += localChild.visits; - root.visits += localChild.visits; - break; - } - } - } - } - - lastIterations = root.visits; - IO.println("V3: " + lastIterations); + lastIterations = root.visits.get(); final Node mostVisitedChild = mostVisitedChild(root); return mostVisitedChild.move; @@ -89,4 +49,15 @@ public class MCTSAI3 extends MCTSAI { return randomSetBit(legalMoves); } } + + private Void iterate(Node root, long endTime) { + while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final int value = simulation(leaf); + backPropagation(leaf, value); + } + + return null; + } } \ No newline at end of file diff --git a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java index b88c0ed..2f55b2f 100644 --- a/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java +++ b/game/src/main/java/org/toop/game/players/ai/mcts/MCTSAI4.java @@ -3,29 +3,27 @@ package org.toop.game.players.ai.mcts; import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.game.players.ai.MCTSAI; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; public class MCTSAI4 extends MCTSAI { - private final int threads; - private final Node[] threadRoots; + private static final int THREADS = Runtime.getRuntime().availableProcessors(); - public MCTSAI4(int milliseconds, int threads) { + private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS); + + private Node root; + + public MCTSAI4(int milliseconds) { super(milliseconds); - this.threads = threads; - this.threadRoots = new Node[threads]; + this.root = null; } public MCTSAI4(MCTSAI4 other) { super(other); - this.threads = other.threads; - this.threadRoots = other.threadRoots; + this.root = other.root; } @Override @@ -35,66 +33,23 @@ public class MCTSAI4 extends MCTSAI { @Override public long getMove(TurnBasedGame game) { - for (int i = 0; i < threads; i++) { - threadRoots[i] = findOrResetRoot(threadRoots[i], game); - } + root = findOrResetRoot(root, game); - final ExecutorService pool = Executors.newFixedThreadPool(threads); final long endTime = System.nanoTime() + milliseconds * 1_000_000L; - final List> tasks = new ArrayList<>(); - - for (int i = 0; i < threads; i++) { - final int threadIndex = i; - - tasks.add(() -> { - final Node localRoot = threadRoots[threadIndex]; - - while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) { - Node leaf = selection(localRoot); - leaf = expansion(leaf); - final float value = simulation(leaf); - backPropagation(leaf, value); - } - - return localRoot; - }); + for (int i = 0; i < THREADS; i++) { + threadPool.submit(() -> iterate(root, endTime)); } try { - final List> results = pool.invokeAll(tasks); + threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS); - pool.shutdown(); - - final Node root = new Node(game.deepCopy()); - - for (int i = 0; i < root.children.length; i++) { - expansion(root); - } - - for (final Future result : results) { - final Node localRoot = result.get(); - - for (final Node localChild : localRoot.children) { - for (int i = 0; i < root.children.length; i++) { - if (localChild.move == root.children[i].move) { - root.children[i].visits += localChild.visits; - root.visits += localChild.visits; - break; - } - } - } - } - - lastIterations = root.visits; - IO.println("V4: " + lastIterations); + lastIterations = root.visits.get(); final Node mostVisitedChild = mostVisitedChild(root); final long move = mostVisitedChild.move; - for (int i = 0; i < threads; i++) { - threadRoots[i] = findChildByMove(threadRoots[i], move); - } + root = findChildByMove(root, move); return move; } catch (Exception _) { @@ -104,4 +59,15 @@ public class MCTSAI4 extends MCTSAI { return randomSetBit(legalMoves); } } + + private Void iterate(Node root, long endTime) { + while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { + Node leaf = selection(root); + leaf = expansion(leaf); + final int value = simulation(leaf); + backPropagation(leaf, value); + } + + return null; + } } \ No newline at end of file diff --git a/game/src/test/java/research/AITest.java b/game/src/test/java/research/AITest.java index 498f91d..8cc97cd 100644 --- a/game/src/test/java/research/AITest.java +++ b/game/src/test/java/research/AITest.java @@ -25,11 +25,11 @@ public class AITest { @BeforeAll public static void init() { - var versions = new ArtificialPlayer[2]; -// versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1"); -// versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2"); - versions[0] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3"); - versions[1] = new ArtificialPlayer(new MCTSAI4(10, 8), "MCTS V4"); + var versions = new ArtificialPlayer[4]; + versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1"); + versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2"); + versions[2] = new ArtificialPlayer(new MCTSAI3(10), "MCTS V3"); + versions[3] = new ArtificialPlayer(new MCTSAI4(10), "MCTS V4"); for (int i = 0; i < versions.length; i++) { for (int j = i + 1; j < versions.length; j++) { final int playerIndex1 = i % versions.length; @@ -96,7 +96,8 @@ public class AITest { addGameData(new GameData( AI1, AI2, - getWinnerForMatch(AI1, AI2, match + getWinnerForMatch(AI1, AI2, match), + match.getAmountOfTurns(), millisecondscounterAI1, millisecondscounterAI2 ));