3 Commits

Author SHA1 Message Date
ramollia
5d2fff7ae7 changed the way multithreading worked 2026-01-21 20:05:19 +01:00
ramollia
f168b974ab Merge remote-tracking branch 'origin/Development' into Development 2026-01-21 15:42:08 +01:00
ramollia
057487e4f9 readded the exploration constant 2026-01-21 15:40:38 +01:00
9 changed files with 106 additions and 163 deletions

View File

@@ -210,7 +210,7 @@ public final class Server {
Player[] players = new Player[2]; Player[] players = new Player[2];
players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000, Runtime.getRuntime().availableProcessors()), user); players[userStartingTurn] = new ArtificialPlayer(new MCTSAI3(1000), user);
players[opponentStartingTurn] = new OnlinePlayer(response.opponent()); players[opponentStartingTurn] = new OnlinePlayer(response.opponent());
switch (type) { switch (type) {

View File

@@ -14,11 +14,9 @@ import org.toop.app.widget.complex.ViewWidget;
import org.toop.app.widget.popup.ErrorPopup; import org.toop.app.widget.popup.ErrorPopup;
import org.toop.app.widget.tutorial.*; import org.toop.app.widget.tutorial.*;
import org.toop.game.players.ai.MiniMaxAI; import org.toop.game.players.ai.MiniMaxAI;
import org.toop.game.players.ai.RandomAI;
import org.toop.game.players.ai.mcts.MCTSAI1; import org.toop.game.players.ai.mcts.MCTSAI1;
import org.toop.game.players.ai.mcts.MCTSAI3; import org.toop.game.players.ai.mcts.MCTSAI3;
import org.toop.game.players.ai.mcts.MCTSAI4; import org.toop.game.players.ai.mcts.MCTSAI4;
import org.toop.game.players.ai.mcts.OPMCTSAI;
import org.toop.local.AppContext; import org.toop.local.AppContext;
import javafx.geometry.Pos; import javafx.geometry.Pos;
@@ -85,12 +83,12 @@ public class LocalMultiplayerView extends ViewWidget {
players[0] = new LocalPlayer(information.players[0].name); players[0] = new LocalPlayer(information.players[0].name);
} else { } else {
// players[0] = new ArtificialPlayer(new RandomAI(), "Random AI"); // players[0] = new ArtificialPlayer(new RandomAI(), "Random AI");
players[0] = new ArtificialPlayer(new MCTSAI4(500, 4), "MCTS V4 AI"); players[0] = new ArtificialPlayer(new MCTSAI1(100), "MCTS V1 AI");
} }
if (information.players[1].isHuman) { if (information.players[1].isHuman) {
players[1] = new LocalPlayer(information.players[1].name); players[1] = new LocalPlayer(information.players[1].name);
} else { } else {
players[1] = new ArtificialPlayer(new MCTSAI1(500), "MCTS V1 AI"); players[1] = new ArtificialPlayer(new MCTSAI4(100), "MCTS V4 AI");
} }
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) { if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
new ShowEnableTutorialWidget( new ShowEnableTutorialWidget(

View File

@@ -3,9 +3,7 @@ package org.toop.framework.game.games.reversi;
import org.toop.framework.game.BitboardGame; import org.toop.framework.game.BitboardGame;
import org.toop.framework.gameFramework.GameState; import org.toop.framework.gameFramework.GameState;
import org.toop.framework.gameFramework.model.game.PlayResult; import org.toop.framework.gameFramework.model.game.PlayResult;
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.Player; import org.toop.framework.gameFramework.model.player.Player;
import org.toop.framework.game.BitboardGame;
public class BitboardReversi extends BitboardGame { public class BitboardReversi extends BitboardGame {

View File

@@ -4,9 +4,12 @@ import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.framework.gameFramework.model.player.AbstractAI; import org.toop.framework.gameFramework.model.player.AbstractAI;
import java.util.Random; import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class MCTSAI extends AbstractAI { public abstract class MCTSAI extends AbstractAI {
protected static class Node { protected static class Node {
public static final int VIRTUAL_LOSS = -1;
public TurnBasedGame state; public TurnBasedGame state;
public long move; public long move;
@@ -15,8 +18,8 @@ public abstract class MCTSAI extends AbstractAI {
public Node parent; public Node parent;
public Node[] children; public Node[] children;
public float value; public AtomicInteger value;
public int visits; public AtomicInteger visits;
public float heuristic; public float heuristic;
@@ -33,8 +36,8 @@ public abstract class MCTSAI extends AbstractAI {
this.parent = parent; this.parent = parent;
this.children = new Node[Long.bitCount(legalMoves)]; this.children = new Node[Long.bitCount(legalMoves)];
this.value = 0.0f; this.value = new AtomicInteger(0);
this.visits = 0; this.visits = new AtomicInteger(0);
this.heuristic = state.rateMove(move); this.heuristic = state.rateMove(move);
@@ -53,14 +56,14 @@ public abstract class MCTSAI extends AbstractAI {
return unexpandedMoves == 0L; return unexpandedMoves == 0L;
} }
public float calculateUCT(int parentVisits) { public float calculateUCT(float explorationFactor) {
if (visits == 0) { if (visits.get() == 0) {
return Float.POSITIVE_INFINITY; return Float.POSITIVE_INFINITY;
} }
final float exploitation = value / visits; final float exploitation = (float) value.get() / visits.get();
final float exploration = (float)(Math.sqrt(Math.log(parentVisits) / visits)); final float exploration = (float)(Math.sqrt(explorationFactor / visits.get()));
final float bias = heuristic * 10.0f / (visits + 1); final float bias = heuristic * 10.0f / (visits.get() + 1);
return exploitation + exploration + bias; return exploitation + exploration + bias;
} }
@@ -72,7 +75,7 @@ public abstract class MCTSAI extends AbstractAI {
float highestUCT = Float.NEGATIVE_INFINITY; float highestUCT = Float.NEGATIVE_INFINITY;
for (int i = 0; i < expanded; i++) { for (int i = 0; i < expanded; i++) {
final float childUCT = children[i].calculateUCT(visits); final float childUCT = children[i].calculateUCT(2.0f * (float)Math.log(visits.get()));
if (childUCT > highestUCT) { if (childUCT > highestUCT) {
highestUCTChild = children[i]; highestUCTChild = children[i];
@@ -108,31 +111,39 @@ public abstract class MCTSAI extends AbstractAI {
protected Node selection(Node root) { protected Node selection(Node root) {
while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) { while (Float.isNaN(root.solved) && root.isFullyExpanded() && !root.state.isTerminal()) {
root.value.addAndGet(Node.VIRTUAL_LOSS);
root.visits.incrementAndGet();
root = root.bestUCTChild(); root = root.bestUCTChild();
} }
root.value.addAndGet(Node.VIRTUAL_LOSS);
root.visits.incrementAndGet();
return root; return root;
} }
protected Node expansion(Node leaf) { protected Node expansion(Node leaf) {
if (leaf.unexpandedMoves == 0L) { synchronized (leaf) {
return leaf; if (leaf.unexpandedMoves == 0L) {
return leaf;
}
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.getExpanded()] = expandedChild;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
} }
final long unexpandedMove = leaf.unexpandedMoves & -leaf.unexpandedMoves;
final TurnBasedGame copiedState = leaf.state.deepCopy();
copiedState.play(unexpandedMove);
final Node expandedChild = new Node(copiedState, leaf, unexpandedMove);
leaf.children[leaf.getExpanded()] = expandedChild;
leaf.unexpandedMoves &= ~unexpandedMove;
return expandedChild;
} }
protected float simulation(Node leaf) { protected int simulation(Node leaf) {
final TurnBasedGame copiedState = leaf.state.deepCopy(); final TurnBasedGame copiedState = leaf.state.deepCopy();
final int playerIndex = 1 - copiedState.getCurrentTurn(); final int playerIndex = 1 - copiedState.getCurrentTurn();
@@ -144,20 +155,20 @@ public abstract class MCTSAI extends AbstractAI {
} }
if (copiedState.getWinner() == playerIndex) { if (copiedState.getWinner() == playerIndex) {
return 1.0f; return 1;
} }
if (copiedState.getWinner() >= 0) { if (copiedState.getWinner() >= 0) {
return -1.0f; return -1;
} }
return 0.0f; return 0;
} }
protected void backPropagation(Node leaf, float value) { protected void backPropagation(Node leaf, int value) {
while (leaf != null) { while (leaf != null) {
leaf.value += value; value -= Node.VIRTUAL_LOSS;
leaf.visits++; leaf.value.addAndGet(value);
if (Float.isNaN(leaf.solved)) { if (Float.isNaN(leaf.solved)) {
updateSolvedStatus(leaf); updateSolvedStatus(leaf);
@@ -175,9 +186,9 @@ public abstract class MCTSAI extends AbstractAI {
int mostVisited = -1; int mostVisited = -1;
for (int i = 0; i < expanded; i++) { for (int i = 0; i < expanded; i++) {
if (root.children[i].visits > mostVisited) { if (root.children[i].visits.get() > mostVisited) {
mostVisitedChild = root.children[i]; mostVisitedChild = root.children[i];
mostVisited = root.children[i].visits; mostVisited = root.children[i].visits.get();
} }
} }

View File

@@ -26,12 +26,11 @@ public class MCTSAI1 extends MCTSAI {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final float value = simulation(leaf); final int value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
lastIterations = root.visits; lastIterations = root.visits.get();
IO.println("V1: " + lastIterations);
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move; return mostVisitedChild.move;

View File

@@ -32,12 +32,11 @@ public class MCTSAI2 extends MCTSAI {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) { while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root); Node leaf = selection(root);
leaf = expansion(leaf); leaf = expansion(leaf);
final float value = simulation(leaf); final int value = simulation(leaf);
backPropagation(leaf, value); backPropagation(leaf, value);
} }
lastIterations = root.visits; lastIterations = root.visits.get();
IO.println("V2: " + lastIterations);
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move; final long move = mostVisitedChild.move;

View File

@@ -3,26 +3,21 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.TimeUnit;
public class MCTSAI3 extends MCTSAI { public class MCTSAI3 extends MCTSAI {
private final int threads; private static final int THREADS = 8;
public MCTSAI3(int milliseconds, int threads) { private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS);
public MCTSAI3(int milliseconds) {
super(milliseconds); super(milliseconds);
this.threads = threads;
} }
public MCTSAI3(MCTSAI3 other) { public MCTSAI3(MCTSAI3 other) {
super(other); super(other);
this.threads = other.threads;
} }
@Override @Override
@@ -32,53 +27,18 @@ public class MCTSAI3 extends MCTSAI {
@Override @Override
public long getMove(TurnBasedGame game) { public long getMove(TurnBasedGame game) {
final ExecutorService pool = Executors.newFixedThreadPool(threads); final Node root = new Node(game.deepCopy(), null, 0L);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>(); for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
for (int i = 0; i < threads; i++) {
tasks.add(() -> {
final Node localRoot = new Node(game.deepCopy());
while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
} }
try { try {
final List<Future<Node>> results = pool.invokeAll(tasks); threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS);
pool.shutdown(); lastIterations = root.visits.get();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
IO.println("V3: " + lastIterations);
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
return mostVisitedChild.move; return mostVisitedChild.move;
@@ -89,4 +49,15 @@ public class MCTSAI3 extends MCTSAI {
return randomSetBit(legalMoves); return randomSetBit(legalMoves);
} }
} }
private Void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
return null;
}
} }

View File

@@ -3,29 +3,27 @@ package org.toop.game.players.ai.mcts;
import org.toop.framework.gameFramework.model.game.TurnBasedGame; import org.toop.framework.gameFramework.model.game.TurnBasedGame;
import org.toop.game.players.ai.MCTSAI; import org.toop.game.players.ai.MCTSAI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.TimeUnit;
public class MCTSAI4 extends MCTSAI { public class MCTSAI4 extends MCTSAI {
private final int threads; private static final int THREADS = Runtime.getRuntime().availableProcessors();
private final Node[] threadRoots;
public MCTSAI4(int milliseconds, int threads) { private static final ExecutorService threadPool = Executors.newFixedThreadPool(THREADS);
private Node root;
public MCTSAI4(int milliseconds) {
super(milliseconds); super(milliseconds);
this.threads = threads; this.root = null;
this.threadRoots = new Node[threads];
} }
public MCTSAI4(MCTSAI4 other) { public MCTSAI4(MCTSAI4 other) {
super(other); super(other);
this.threads = other.threads; this.root = other.root;
this.threadRoots = other.threadRoots;
} }
@Override @Override
@@ -35,66 +33,23 @@ public class MCTSAI4 extends MCTSAI {
@Override @Override
public long getMove(TurnBasedGame game) { public long getMove(TurnBasedGame game) {
for (int i = 0; i < threads; i++) { root = findOrResetRoot(root, game);
threadRoots[i] = findOrResetRoot(threadRoots[i], game);
}
final ExecutorService pool = Executors.newFixedThreadPool(threads);
final long endTime = System.nanoTime() + milliseconds * 1_000_000L; final long endTime = System.nanoTime() + milliseconds * 1_000_000L;
final List<Callable<Node>> tasks = new ArrayList<>(); for (int i = 0; i < THREADS; i++) {
threadPool.submit(() -> iterate(root, endTime));
for (int i = 0; i < threads; i++) {
final int threadIndex = i;
tasks.add(() -> {
final Node localRoot = threadRoots[threadIndex];
while (Float.isNaN(localRoot.solved) && System.nanoTime() < endTime) {
Node leaf = selection(localRoot);
leaf = expansion(leaf);
final float value = simulation(leaf);
backPropagation(leaf, value);
}
return localRoot;
});
} }
try { try {
final List<Future<Node>> results = pool.invokeAll(tasks); threadPool.awaitTermination(milliseconds + 50, TimeUnit.MILLISECONDS);
pool.shutdown(); lastIterations = root.visits.get();
final Node root = new Node(game.deepCopy());
for (int i = 0; i < root.children.length; i++) {
expansion(root);
}
for (final Future<Node> result : results) {
final Node localRoot = result.get();
for (final Node localChild : localRoot.children) {
for (int i = 0; i < root.children.length; i++) {
if (localChild.move == root.children[i].move) {
root.children[i].visits += localChild.visits;
root.visits += localChild.visits;
break;
}
}
}
}
lastIterations = root.visits;
IO.println("V4: " + lastIterations);
final Node mostVisitedChild = mostVisitedChild(root); final Node mostVisitedChild = mostVisitedChild(root);
final long move = mostVisitedChild.move; final long move = mostVisitedChild.move;
for (int i = 0; i < threads; i++) { root = findChildByMove(root, move);
threadRoots[i] = findChildByMove(threadRoots[i], move);
}
return move; return move;
} catch (Exception _) { } catch (Exception _) {
@@ -104,4 +59,15 @@ public class MCTSAI4 extends MCTSAI {
return randomSetBit(legalMoves); return randomSetBit(legalMoves);
} }
} }
private Void iterate(Node root, long endTime) {
while (Float.isNaN(root.solved) && System.nanoTime() < endTime) {
Node leaf = selection(root);
leaf = expansion(leaf);
final int value = simulation(leaf);
backPropagation(leaf, value);
}
return null;
}
} }

View File

@@ -25,11 +25,11 @@ public class AITest {
@BeforeAll @BeforeAll
public static void init() { public static void init() {
var versions = new ArtificialPlayer[2]; var versions = new ArtificialPlayer[4];
// versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1"); versions[0] = new ArtificialPlayer(new MCTSAI1(10), "MCTS V1");
// versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2"); versions[1] = new ArtificialPlayer(new MCTSAI2(10), "MCTS V2");
versions[0] = new ArtificialPlayer(new MCTSAI3(10, 8), "MCTS V3"); versions[2] = new ArtificialPlayer(new MCTSAI3(10), "MCTS V3");
versions[1] = new ArtificialPlayer(new MCTSAI4(10, 8), "MCTS V4"); versions[3] = new ArtificialPlayer(new MCTSAI4(10), "MCTS V4");
for (int i = 0; i < versions.length; i++) { for (int i = 0; i < versions.length; i++) {
for (int j = i + 1; j < versions.length; j++) { for (int j = i + 1; j < versions.length; j++) {
final int playerIndex1 = i % versions.length; final int playerIndex1 = i % versions.length;
@@ -96,7 +96,8 @@ public class AITest {
addGameData(new GameData( addGameData(new GameData(
AI1, AI1,
AI2, AI2,
getWinnerForMatch(AI1, AI2, match getWinnerForMatch(AI1, AI2, match),
match.getAmountOfTurns(),
millisecondscounterAI1, millisecondscounterAI1,
millisecondscounterAI2 millisecondscounterAI2
)); ));