mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 02:44:50 +00:00
mcts v1
This commit is contained in:
@@ -15,6 +15,7 @@ import org.toop.app.widget.complex.PlayerInfoWidget;
|
||||
import org.toop.app.widget.complex.ViewWidget;
|
||||
import org.toop.app.widget.popup.ErrorPopup;
|
||||
import org.toop.app.widget.tutorial.*;
|
||||
import org.toop.game.players.ai.MCTSAI;
|
||||
import org.toop.game.players.ai.MiniMaxAI;
|
||||
import org.toop.game.players.ai.RandomAI;
|
||||
import org.toop.local.AppContext;
|
||||
@@ -87,7 +88,7 @@ public class LocalMultiplayerView extends ViewWidget {
|
||||
if (information.players[1].isHuman) {
|
||||
players[1] = new LocalPlayer<>(information.players[1].name);
|
||||
} else {
|
||||
players[1] = new ArtificialPlayer<>(new MiniMaxAI<BitboardReversi>(6), "MiniMax");
|
||||
players[1] = new ArtificialPlayer<>(new MCTSAI<BitboardReversi>(1000), "MCTS AI");
|
||||
}
|
||||
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
|
||||
new ShowEnableTutorialWidget(
|
||||
|
||||
191
game/src/main/java/org/toop/game/players/ai/MCTSAI.java
Normal file
191
game/src/main/java/org/toop/game/players/ai/MCTSAI.java
Normal file
@@ -0,0 +1,191 @@
|
||||
package org.toop.game.players.ai;
|
||||
|
||||
import org.toop.framework.gameFramework.GameState;
|
||||
import org.toop.framework.gameFramework.model.game.PlayResult;
|
||||
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||
private static class Node {
|
||||
public TurnBasedGame<?> state;
|
||||
public long move;
|
||||
|
||||
public Node parent;
|
||||
|
||||
public int expanded;
|
||||
public Node[] children;
|
||||
|
||||
public int visits;
|
||||
public float value;
|
||||
|
||||
public Node(TurnBasedGame<?> state, long move, Node parent) {
|
||||
this.state = state;
|
||||
this.move = move;
|
||||
|
||||
this.parent = parent;
|
||||
|
||||
this.expanded = 0;
|
||||
this.children = new Node[Long.bitCount(state.getLegalMoves())];
|
||||
|
||||
this.visits = 0;
|
||||
this.value = 0.0f;
|
||||
}
|
||||
|
||||
public Node(TurnBasedGame<?> state) {
|
||||
this(state, 0L, null);
|
||||
}
|
||||
|
||||
public boolean isFullyExpanded() {
|
||||
return expanded >= children.length;
|
||||
}
|
||||
|
||||
public Node bestUCTChild(float explorationFactor) {
|
||||
int bestChildIndex = -1;
|
||||
float bestScore = Float.NEGATIVE_INFINITY;
|
||||
|
||||
for (int i = 0; i < expanded; i++) {
|
||||
float exploitation = children[i].visits <= 0? 0 : children[i].value / children[i].visits;
|
||||
float exploration = explorationFactor * (float)(Math.sqrt(Math.log(visits) / (children[i].visits + 0.001f)));
|
||||
|
||||
float score = exploitation + exploration;
|
||||
|
||||
if (score > bestScore) {
|
||||
bestChildIndex = i;
|
||||
bestScore = score;
|
||||
}
|
||||
}
|
||||
|
||||
return bestChildIndex >= 0? children[bestChildIndex] : this;
|
||||
}
|
||||
}
|
||||
|
||||
private final int milliseconds;
|
||||
|
||||
public MCTSAI(int milliseconds) {
|
||||
this.milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
public MCTSAI(MCTSAI<T> other) {
|
||||
this.milliseconds = other.milliseconds;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MCTSAI<T> deepCopy() {
|
||||
return new MCTSAI<>(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getMove(T game) {
|
||||
Node root = new Node(game.deepCopy());
|
||||
|
||||
long endTime = System.currentTimeMillis() + milliseconds;
|
||||
|
||||
while (System.currentTimeMillis() <= endTime) {
|
||||
Node node = selection(root);
|
||||
long legalMoves = node.state.getLegalMoves();
|
||||
|
||||
if (legalMoves != 0) {
|
||||
node = expansion(node, legalMoves);
|
||||
}
|
||||
|
||||
float result = 0.0f;
|
||||
|
||||
if (node.state.getLegalMoves() != 0) {
|
||||
result = simulation(node.state, game.getCurrentTurn());
|
||||
}
|
||||
|
||||
backPropagation(node, result);
|
||||
}
|
||||
|
||||
int mostVisitedIndex = -1;
|
||||
int mostVisits = -1;
|
||||
|
||||
for (int i = 0; i < root.expanded; i++) {
|
||||
if (root.children[i].visits > mostVisits) {
|
||||
mostVisitedIndex = i;
|
||||
mostVisits = root.children[i].visits;
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Visit count: " + root.visits);
|
||||
|
||||
return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves());
|
||||
}
|
||||
|
||||
private Node selection(Node node) {
|
||||
while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) {
|
||||
node = node.bestUCTChild(1.41f);
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
private Node expansion(Node node, long legalMoves) {
|
||||
for (int i = 0; i < node.expanded; i++) {
|
||||
legalMoves &= ~node.children[i].move;
|
||||
}
|
||||
|
||||
if (legalMoves == 0L) {
|
||||
return node;
|
||||
}
|
||||
|
||||
long move = randomSetBit(legalMoves);
|
||||
|
||||
TurnBasedGame<?> copy = node.state.deepCopy();
|
||||
copy.play(move);
|
||||
|
||||
Node newlyExpanded = new Node(copy, move, node);
|
||||
|
||||
node.children[node.expanded] = newlyExpanded;
|
||||
node.expanded++;
|
||||
|
||||
return newlyExpanded;
|
||||
}
|
||||
|
||||
private float simulation(TurnBasedGame<?> state, int playerIndex) {
|
||||
TurnBasedGame<?> copy = state.deepCopy();
|
||||
long legalMoves = copy.getLegalMoves();
|
||||
PlayResult result = null;
|
||||
|
||||
while (legalMoves != 0) {
|
||||
result = copy.play(randomSetBit(legalMoves));
|
||||
legalMoves = copy.getLegalMoves();
|
||||
}
|
||||
|
||||
if (result.state() == GameState.WIN) {
|
||||
if (result.player() == playerIndex) {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
return -1.0f;
|
||||
}
|
||||
|
||||
return -0.2f;
|
||||
}
|
||||
|
||||
private void backPropagation(Node node, float value) {
|
||||
while (node != null) {
|
||||
node.visits++;
|
||||
node.value += value;
|
||||
node = node.parent;
|
||||
}
|
||||
}
|
||||
|
||||
public static long randomSetBit(long value) {
|
||||
Random random = new Random();
|
||||
|
||||
int count = Long.bitCount(value);
|
||||
int target = random.nextInt(count);
|
||||
|
||||
while (true) {
|
||||
int bit = Long.numberOfTrailingZeros(value);
|
||||
if (target == 0) {
|
||||
return 1L << bit;
|
||||
}
|
||||
value &= value - 1;
|
||||
target--;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user