mirror of
https://github.com/2OOP/pism.git
synced 2026-02-04 10:54:51 +00:00
mcts v1
This commit is contained in:
@@ -15,6 +15,7 @@ import org.toop.app.widget.complex.PlayerInfoWidget;
|
|||||||
import org.toop.app.widget.complex.ViewWidget;
|
import org.toop.app.widget.complex.ViewWidget;
|
||||||
import org.toop.app.widget.popup.ErrorPopup;
|
import org.toop.app.widget.popup.ErrorPopup;
|
||||||
import org.toop.app.widget.tutorial.*;
|
import org.toop.app.widget.tutorial.*;
|
||||||
|
import org.toop.game.players.ai.MCTSAI;
|
||||||
import org.toop.game.players.ai.MiniMaxAI;
|
import org.toop.game.players.ai.MiniMaxAI;
|
||||||
import org.toop.game.players.ai.RandomAI;
|
import org.toop.game.players.ai.RandomAI;
|
||||||
import org.toop.local.AppContext;
|
import org.toop.local.AppContext;
|
||||||
@@ -87,7 +88,7 @@ public class LocalMultiplayerView extends ViewWidget {
|
|||||||
if (information.players[1].isHuman) {
|
if (information.players[1].isHuman) {
|
||||||
players[1] = new LocalPlayer<>(information.players[1].name);
|
players[1] = new LocalPlayer<>(information.players[1].name);
|
||||||
} else {
|
} else {
|
||||||
players[1] = new ArtificialPlayer<>(new MiniMaxAI<BitboardReversi>(6), "MiniMax");
|
players[1] = new ArtificialPlayer<>(new MCTSAI<BitboardReversi>(1000), "MCTS AI");
|
||||||
}
|
}
|
||||||
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
|
if (AppSettings.getSettings().getTutorialFlag() && AppSettings.getSettings().getFirstReversi()) {
|
||||||
new ShowEnableTutorialWidget(
|
new ShowEnableTutorialWidget(
|
||||||
|
|||||||
191
game/src/main/java/org/toop/game/players/ai/MCTSAI.java
Normal file
191
game/src/main/java/org/toop/game/players/ai/MCTSAI.java
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package org.toop.game.players.ai;
|
||||||
|
|
||||||
|
import org.toop.framework.gameFramework.GameState;
|
||||||
|
import org.toop.framework.gameFramework.model.game.PlayResult;
|
||||||
|
import org.toop.framework.gameFramework.model.game.TurnBasedGame;
|
||||||
|
import org.toop.framework.gameFramework.model.player.AbstractAI;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
public class MCTSAI<T extends TurnBasedGame<T>> extends AbstractAI<T> {
|
||||||
|
private static class Node {
|
||||||
|
public TurnBasedGame<?> state;
|
||||||
|
public long move;
|
||||||
|
|
||||||
|
public Node parent;
|
||||||
|
|
||||||
|
public int expanded;
|
||||||
|
public Node[] children;
|
||||||
|
|
||||||
|
public int visits;
|
||||||
|
public float value;
|
||||||
|
|
||||||
|
public Node(TurnBasedGame<?> state, long move, Node parent) {
|
||||||
|
this.state = state;
|
||||||
|
this.move = move;
|
||||||
|
|
||||||
|
this.parent = parent;
|
||||||
|
|
||||||
|
this.expanded = 0;
|
||||||
|
this.children = new Node[Long.bitCount(state.getLegalMoves())];
|
||||||
|
|
||||||
|
this.visits = 0;
|
||||||
|
this.value = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node(TurnBasedGame<?> state) {
|
||||||
|
this(state, 0L, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isFullyExpanded() {
|
||||||
|
return expanded >= children.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Node bestUCTChild(float explorationFactor) {
|
||||||
|
int bestChildIndex = -1;
|
||||||
|
float bestScore = Float.NEGATIVE_INFINITY;
|
||||||
|
|
||||||
|
for (int i = 0; i < expanded; i++) {
|
||||||
|
float exploitation = children[i].visits <= 0? 0 : children[i].value / children[i].visits;
|
||||||
|
float exploration = explorationFactor * (float)(Math.sqrt(Math.log(visits) / (children[i].visits + 0.001f)));
|
||||||
|
|
||||||
|
float score = exploitation + exploration;
|
||||||
|
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestChildIndex = i;
|
||||||
|
bestScore = score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestChildIndex >= 0? children[bestChildIndex] : this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final int milliseconds;
|
||||||
|
|
||||||
|
public MCTSAI(int milliseconds) {
|
||||||
|
this.milliseconds = milliseconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MCTSAI(MCTSAI<T> other) {
|
||||||
|
this.milliseconds = other.milliseconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MCTSAI<T> deepCopy() {
|
||||||
|
return new MCTSAI<>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getMove(T game) {
|
||||||
|
Node root = new Node(game.deepCopy());
|
||||||
|
|
||||||
|
long endTime = System.currentTimeMillis() + milliseconds;
|
||||||
|
|
||||||
|
while (System.currentTimeMillis() <= endTime) {
|
||||||
|
Node node = selection(root);
|
||||||
|
long legalMoves = node.state.getLegalMoves();
|
||||||
|
|
||||||
|
if (legalMoves != 0) {
|
||||||
|
node = expansion(node, legalMoves);
|
||||||
|
}
|
||||||
|
|
||||||
|
float result = 0.0f;
|
||||||
|
|
||||||
|
if (node.state.getLegalMoves() != 0) {
|
||||||
|
result = simulation(node.state, game.getCurrentTurn());
|
||||||
|
}
|
||||||
|
|
||||||
|
backPropagation(node, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
int mostVisitedIndex = -1;
|
||||||
|
int mostVisits = -1;
|
||||||
|
|
||||||
|
for (int i = 0; i < root.expanded; i++) {
|
||||||
|
if (root.children[i].visits > mostVisits) {
|
||||||
|
mostVisitedIndex = i;
|
||||||
|
mostVisits = root.children[i].visits;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
System.out.println("Visit count: " + root.visits);
|
||||||
|
|
||||||
|
return mostVisitedIndex != -1? root.children[mostVisitedIndex].move : randomSetBit(game.getLegalMoves());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node selection(Node node) {
|
||||||
|
while (node.state.getLegalMoves() != 0L && node.isFullyExpanded()) {
|
||||||
|
node = node.bestUCTChild(1.41f);
|
||||||
|
}
|
||||||
|
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Node expansion(Node node, long legalMoves) {
|
||||||
|
for (int i = 0; i < node.expanded; i++) {
|
||||||
|
legalMoves &= ~node.children[i].move;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (legalMoves == 0L) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
long move = randomSetBit(legalMoves);
|
||||||
|
|
||||||
|
TurnBasedGame<?> copy = node.state.deepCopy();
|
||||||
|
copy.play(move);
|
||||||
|
|
||||||
|
Node newlyExpanded = new Node(copy, move, node);
|
||||||
|
|
||||||
|
node.children[node.expanded] = newlyExpanded;
|
||||||
|
node.expanded++;
|
||||||
|
|
||||||
|
return newlyExpanded;
|
||||||
|
}
|
||||||
|
|
||||||
|
private float simulation(TurnBasedGame<?> state, int playerIndex) {
|
||||||
|
TurnBasedGame<?> copy = state.deepCopy();
|
||||||
|
long legalMoves = copy.getLegalMoves();
|
||||||
|
PlayResult result = null;
|
||||||
|
|
||||||
|
while (legalMoves != 0) {
|
||||||
|
result = copy.play(randomSetBit(legalMoves));
|
||||||
|
legalMoves = copy.getLegalMoves();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result.state() == GameState.WIN) {
|
||||||
|
if (result.player() == playerIndex) {
|
||||||
|
return 1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
return -0.2f;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void backPropagation(Node node, float value) {
|
||||||
|
while (node != null) {
|
||||||
|
node.visits++;
|
||||||
|
node.value += value;
|
||||||
|
node = node.parent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static long randomSetBit(long value) {
|
||||||
|
Random random = new Random();
|
||||||
|
|
||||||
|
int count = Long.bitCount(value);
|
||||||
|
int target = random.nextInt(count);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
int bit = Long.numberOfTrailingZeros(value);
|
||||||
|
if (target == 0) {
|
||||||
|
return 1L << bit;
|
||||||
|
}
|
||||||
|
value &= value - 1;
|
||||||
|
target--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user