How to build a Game Tree using alpha-beta pruning - java

I'm trying to build a game tree for my game in order to find my next move.
At first, I build the tree using a recursive algorithm, and then, to find the best move, I use the alpha-beta pruning algorithm.
I want to apply alpha-beta pruning while building the game tree in order to minimize its size, but I'm having trouble writing the algorithm.
Could you help me add alpha-beta pruning to the expand algorithm?
Here is the expand algorithm:
public void expand(int depth)
{
expand++;
if(depth > 0)
{
this.children = new ArrayList<GameTreeNode>();
List<Move> possibleMoves = this.b.possibleMoves(this.b.turn);
ReversiBoard tmp = null;
for(Move m : possibleMoves)
{
TurnState nextState = (this.state == TurnState.PLUS ? TurnState.MINUS : TurnState.PLUS);
tmp = new ReversiBoard(this.b);
tmp.makeMove(m);
int nextTurn = (turn == PLAYER1 ? PLAYER2 : PLAYER1);
if(tmp.possibleMoves(nextTurn).isEmpty())
nextTurn = turn;
this.children.add(new GameTreeNode(tmp, nextState, m, nextTurn));
for(GameTreeNode child : children)
child.expand(depth - 1);
}
}
}
Here is the alpha - beta pruning code:
int alphaBetaMax( int alpha, int beta, int depthleft ) {
alphaBetaNum++;
if ( depthleft == 0 ) return this.b.evaluate();
for (GameTreeNode tree : this.children) {
bestValue = alphaBetaMin( alpha, beta, depthleft - 1 );
if( bestValue >= beta )
{
bestMove = tree.move;
return beta; // fail hard beta-cutoff
}
if( bestValue > alpha )
alpha = bestValue; // alpha acts like max in MiniMax
}
return alpha;
}
int alphaBetaMin( int alpha, int beta, int depthleft ) {
alphaBetaNum++;
if ( depthleft == 0 ) return -this.b.evaluate();
for ( GameTreeNode tree : this.children) {
bestValue = alphaBetaMax( alpha, beta, depthleft - 1 );
if( bestValue <= alpha )
{
bestMove = tree.move;
return alpha; // fail hard alpha-cutoff
}
if( bestValue < beta )
beta = bestValue; // beta acts like min in MiniMax
}
return beta;
}
public void summonAlphaBeta(int depth)
{
this.bestValue = alphaBetaMax(Integer.MIN_VALUE, Integer.MAX_VALUE, depth);
}
Thank You!

You have two options.
You could just combine the two algorithms by converting your expand method into expandAndReturnMin and expandAndReturnMax methods which each take the alpha and beta values as arguments. Ideally any shared code would be put into a third method to keep your code clean.
Here is some example code for you to consider. In this example I've assumed a static member is storing the best move.
public int bestValue(Board board, int depth, int alpha, int beta, boolean aiPlayer) {
    if (depth >= MAX_DEPTH || board.possibleMoves(aiPlayer).isEmpty()) {
        return board.getValue();
    } else {
        for (Move move : board.possibleMoves(aiPlayer)) {
            int value = bestValue(board.makeMove(move), depth + 1, alpha, beta, !aiPlayer);
            if (aiPlayer && value > alpha) {
                alpha = value;
                if (depth == 0)
                    bestMove = move; // record root moves only; deeper calls must not overwrite it
                if (alpha >= beta)
                    break; // beta cutoff
            } else if (!aiPlayer && value < beta) {
                beta = value;
                if (depth == 0)
                    bestMove = move;
                if (beta <= alpha)
                    break; // alpha cutoff
            }
        }
        return aiPlayer ? alpha : beta;
    }
}
The best initial move is determined by:
board.bestValue(board, 0, Integer.MIN_VALUE, Integer.MAX_VALUE, true);
and then using board.getBestMove().
A more elegant solution would be to store the alpha and beta values in the tree itself. That is very simple: after generating each child node you update the values in the current node. Then if they fall outside the allowed range you can stop generating child nodes. This is the more standard approach and is computationally cheap but makes the nodes use more memory.
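As a rough sketch of this second option, the code below stores alpha and beta in each node and stops generating children once the window closes. The Node class, the toy game (take 1 to 3 stones, whoever takes the last stone wins) and the generated counter are stand-ins invented for illustration; they are not the asker's ReversiBoard or GameTreeNode API.

```java
import java.util.ArrayList;
import java.util.List;

public class PrunedExpansion {
    static int generated = 0; // number of nodes actually expanded

    // Hypothetical game-tree node for the toy stone-taking game.
    static final class Node {
        final int stones;         // stones left in the pile
        final boolean maximizing; // true if the maximizing player moves here
        final List<Node> children = new ArrayList<>();
        int alpha, beta, value;   // search window and result, stored in the node

        Node(int stones, boolean maximizing, int alpha, int beta) {
            this.stones = stones;
            this.maximizing = maximizing;
            this.alpha = alpha;
            this.beta = beta;
        }

        // Generates children one at a time and stops as soon as alpha >= beta.
        int expand(int depth) {
            generated++;
            // Empty pile: the player who just moved took the last stone and won.
            if (stones == 0) return value = maximizing ? -1 : 1;
            if (depth == 0) return value = 0; // depth limit: outcome unknown
            value = maximizing ? Integer.MIN_VALUE : Integer.MAX_VALUE;
            for (int take = 1; take <= Math.min(3, stones); take++) {
                Node child = new Node(stones - take, !maximizing, alpha, beta);
                children.add(child);
                int v = child.expand(depth - 1);
                if (maximizing) {
                    value = Math.max(value, v);
                    alpha = Math.max(alpha, value);
                } else {
                    value = Math.min(value, v);
                    beta = Math.min(beta, value);
                }
                if (alpha >= beta) break; // remaining siblings are never generated
            }
            return value;
        }
    }

    public static void main(String[] args) {
        Node root = new Node(5, true, Integer.MIN_VALUE, Integer.MAX_VALUE);
        System.out.println("value = " + root.expand(10) + ", nodes generated = " + generated);
    }
}
```

The trade-off matches the description above: each node carries three extra ints, but subtrees behind a cutoff are never allocated.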

Related

Is there a problem with my Negamax algorithm with alpha-beta pruning?

I'm trying to build a chess AI. My negamax function with alpha-beta pruning (ABP) runs much slower (about 8 times) than separate min and max functions also with ABP, though the moves returned are equal.
My board evaluation function always returns a value with respect to the red player, i.e. the higher the better for red. For Negamax only, this value is multiplied by -1 for the black player when returning at depth 0.
My Negamax function:
int alphaBeta(Board board, int depth, int alpha, int beta) {
if (depth <= 0 || board.isGameOver()) { // game over == checkmate/stalemate
int color = board.getCurrPlayer().getAlliance().isRed() ? 1 : -1;
return BoardEvaluator.evaluate(board, depth) * color;
}
int bestVal = Integer.MIN_VALUE + 1;
for (Move move : MoveSorter.simpleSort(board.getCurrPlayer().getLegalMoves())) {
MoveTransition transition = board.getCurrPlayer().makeMove(move);
if (transition.getMoveStatus().isAllowed()) { // allowed == legal && non-suicidal
int val = -alphaBeta(transition.getNextBoard(), depth - 1, -beta, -alpha);
if (val >= beta) {
return val; // fail-soft
}
if (val > bestVal) {
bestVal = val;
alpha = Math.max(alpha, val);
}
}
}
return bestVal;
}
The root call:
-alphaBeta(transition.getNextBoard(), searchDepth - 1,
Integer.MIN_VALUE + 1, Integer.MAX_VALUE); // +1 to avoid overflow when negating
My min and max functions:
int min(Board board, int depth, int alpha, int beta) {
if (depth <= 0 || board.isGameOver()) {
return BoardEvaluator.evaluate(board, depth);
}
int minValue = Integer.MAX_VALUE;
for (Move move : MoveSorter.simpleSort(board.getCurrPlayer().getLegalMoves())) {
MoveTransition transition = board.getCurrPlayer().makeMove(move);
if (transition.getMoveStatus().isAllowed()) {
minValue = Math.min(minValue, max(transition.getNextBoard(), depth - 1, alpha, beta));
beta = Math.min(beta, minValue);
if (alpha >= beta) break; // cutoff
}
}
return minValue;
}
int max(Board board, int depth, int alpha, int beta) {
if (depth <= 0 || board.isGameOver()) {
return BoardEvaluator.evaluate(board, depth);
}
int maxValue = Integer.MIN_VALUE;
for (Move move : MoveSorter.simpleSort(board.getCurrPlayer().getLegalMoves())) {
MoveTransition transition = board.getCurrPlayer().makeMove(move);
if (transition.getMoveStatus().isAllowed()) {
maxValue = Math.max(maxValue, min(transition.getNextBoard(), depth - 1, alpha, beta));
alpha = Math.max(alpha, maxValue);
if (alpha >= beta) break; // cutoff
}
}
return maxValue;
}
The root calls for red and black players respectively:
min(transition.getNextBoard(), searchDepth - 1, Integer.MIN_VALUE, Integer.MAX_VALUE);
max(transition.getNextBoard(), searchDepth - 1, Integer.MIN_VALUE, Integer.MAX_VALUE);
I'm guessing there's a bug with the cutoff in the Negamax function although I followed the pseudocode from here. Any help is appreciated, thanks!
EDIT: alphaBeta() is called about 6 times more than min() and max() combined, while the number of beta cutoffs is only about 2 times more.
Solved. I should have posted my full code for the root calls as well -- didn't realise I wasn't passing in the new value for beta. Alpha/beta was actually being updated in the root method for separate min-max.
Updated root method for Negamax:
Move bestMove = null;
int bestVal = Integer.MIN_VALUE + 1;
for (Move move : MoveSorter.simpleSort(currBoard.getCurrPlayer().getLegalMoves())) {
MoveTransition transition = currBoard.getCurrPlayer().makeMove(move);
if (transition.getMoveStatus().isAllowed()) {
int val = -alphaBeta(transition.getNextBoard(), searchDepth - 1, Integer.MIN_VALUE + 1, -bestVal);
if (val > bestVal) {
bestVal = val;
bestMove = move;
}
}
}
return bestMove;
Apologies for the lack of information provided in my question -- I didn't expect the bug to be there.

Alpha Beta Pruning with Binary Search Tree

I am working through the Minimax algorithm with Alpha-Beta Pruning example found here. In the example, they use an array to implement the search tree. I followed the example, but also tried implementing it with a binary search tree as well. Here are the values I'm using in the tree: 3, 5, 6, 9, 1, 2, 0, -1.
The optimal value at the end should be 5. With the BST implementation, I keep getting 2.
I think this is the problem, but I don't know how to get around it:
I wrote the code to return out of recursion if it sees a leaf node to stop from getting null pointer exceptions when trying to check the next value. But instead, I think it's stopping the search too early (based off of what I see when stepping through the code with the debugger). If I remove the check though, the code fails on a null pointer.
Can someone point me in the right direction? What am I doing wrong?
Here's the code:
public class AlphaBetaMiniMax {
private static BinarySearchTree myTree = new BinarySearchTree();
static int MAX = 1000;
static int MIN = -1000;
static int opt;
public static void main(String[] args) {
//Start constructing the game
AlphaBetaMiniMax demo = new AlphaBetaMiniMax();
//3, 5, 6, 9, 1, 2, 0, -1
demo.myTree.insert(3);
demo.myTree.insert(5);
demo.myTree.insert(6);
demo.myTree.insert(9);
demo.myTree.insert(1);
demo.myTree.insert(2);
demo.myTree.insert(0);
demo.myTree.insert(-1);
//print the tree
System.out.println("Game Tree: ");
demo.myTree.printTree(demo.myTree.root);
//Print the results of the game
System.out.println("\nGame Results:");
//run the minimax algorithm with the following inputs
int optimalVal = demo.minimax(0, myTree.root, true, MAX, MIN);
System.out.println("Optimal Value: " + optimalVal);
}
/**
* @param alpha = 1000
* @param beta = -1000
* @param nodeIndex - the current node
* @param depth - the depth to search
* @param maximizingPlayer - the current player making a move
* @return - the best move for the current player
*/
public int minimax(int depth, MiniMaxNode nodeIndex, boolean maximizingPlayer, double alpha, double beta) {
//Base Case #1: Reached the bottom of the tree
if (depth == 2) {
return nodeIndex.getValue();
}
//Base Case #2: if reached a leaf node, return the value of the current node
if (nodeIndex.getLeft() == null && maximizingPlayer == false) {
return nodeIndex.getValue();
} else if (nodeIndex.getRight() == null && maximizingPlayer == true) {
return nodeIndex.getValue();
}
//Mini-Max Algorithm
if (maximizingPlayer) {
int best = MIN;
//Recur for left and right children
for (int i = 0; i < 2; i++) {
int val = minimax(depth + 1, nodeIndex.getLeft(), false, alpha, beta);
best = Math.max(best, val);
alpha = Math.max(alpha, best);
//Alpha Beta Pruning
if (beta <= alpha) {
break;
}
}
return best;
} else {
int best = MAX;
//Recur for left and right children
for (int i = 0; i < 2; i++) {
int val = minimax(depth + 1, nodeIndex.getRight(), true, alpha, beta);
best = Math.min(best, val);
beta = Math.min(beta, best);
//Alpha Beta Pruning
if (beta <= alpha) {
break;
}
}
return best;
}
}
}
Output:
Game Tree:
-1 ~ 0 ~ 1 ~ 2 ~ 3 ~ 5 ~ 6 ~ 9 ~
Game Results:
Optimal Value: 2
Your problem is that your iteration is controlled by a fixed loop count of 2, rather than by checking whether the next node is null: nodeIndex.getRight() (for max) or nodeIndex.getLeft() (for min).
Remember that a full binary tree has 1 node (the head) on the first level, 2 on the second, 4 on the third, 8 on the fourth, and so on. So your looping algorithm will not even go down 3 levels.
for (int i = 0; i < 2; i++) {
    int val = minimax(depth + 1, nodeIndex.getLeft(), false, alpha, beta);
    best = Math.max(best, val);
    alpha = Math.max(alpha, best);
    //Alpha Beta Pruning
    if (beta <= alpha) {
        break;
    }
}
Change your loops to control iteration correctly and you should find the highest value easily.
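One possible reading of this advice, as a sketch: treat the eight values as the leaves of a complete binary game tree (not as keys inserted into a BST, which would sort them), visit both children of every inner node, and start with alpha at the minimum and beta at the maximum. The Node class and the build helper are hypothetical and assume the number of leaves is a power of two.

```java
public class BinaryGameTree {
    static final int MAX = 1000;
    static final int MIN = -1000;

    // Hypothetical node: inner nodes have two children, leaves carry a value.
    static final class Node {
        final int value;
        final Node left, right;

        Node(int value) { this(value, null, null); }
        Node(Node left, Node right) { this(0, left, right); }
        private Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }

        boolean isLeaf() { return left == null && right == null; }
    }

    // Builds a complete binary tree over leaves[lo..hi), hi exclusive.
    static Node build(int[] leaves, int lo, int hi) {
        if (hi - lo == 1) return new Node(leaves[lo]);
        int mid = (lo + hi) / 2;
        return new Node(build(leaves, lo, mid), build(leaves, mid, hi));
    }

    static int minimax(Node node, boolean maximizing, int alpha, int beta) {
        if (node.isLeaf()) return node.value;
        int best = maximizing ? MIN : MAX;
        // Visit the left child, then the right child, unless pruned.
        for (Node child : new Node[] { node.left, node.right }) {
            int val = minimax(child, !maximizing, alpha, beta);
            if (maximizing) {
                best = Math.max(best, val);
                alpha = Math.max(alpha, best);
            } else {
                best = Math.min(best, val);
                beta = Math.min(beta, best);
            }
            if (beta <= alpha) break; // alpha-beta cutoff
        }
        return best;
    }

    public static void main(String[] args) {
        Node root = build(new int[] { 3, 5, 6, 9, 1, 2, 0, -1 }, 0, 8);
        // Note the argument order: alpha starts at MIN, beta at MAX.
        System.out.println("Optimal Value: " + minimax(root, true, MIN, MAX)); // prints Optimal Value: 5
    }
}
```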

Minimax result from random values giving unexpectedly low result

I implemented the minimax, PVS and alpha-beta algorithms and compared their results on a random tree. Each parent in this tree has [2,10] children, and the total depth is 10. Each leaf node has a random value in [0,10].
When I run the minimax algorithm on the tree, the resulting value is usually 2 or 3. This seems odd to me, as I would have guessed it would give 5, or maybe 4 or 6, but it's always 2 or 3. Since minimax should give the max of the mins, etc., I'm confused about why it seems to give almost the minimum of the whole tree. FYI, I start the algorithms as the maximizing player.
These are the results:
Alpha Beta: 2.0, time: 10.606461 milli seconds
PVS: 2.0, time: 41.119652 milli seconds
Minimax: 2.0, time: 184.492937 milli seconds
This is the source code, excluding the Timer class as that's not relevant to my question.
import testing.utilities.data.Timer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
public class MinimaxAlphaBetaTest {
public static void main(String[] args) {
Node parent = new Node(0.);
int depth = 10;
createTree(parent,depth);
Timer t = new Timer().start();
double ab = alphabeta(parent,depth+1,Double.NEGATIVE_INFINITY,Double.POSITIVE_INFINITY,true);
t.stop();
System.out.println("Alpha Beta: "+ab+", time: "+t.getTime());
t = new Timer().start();
double pv = pvs(parent,depth+1,Double.NEGATIVE_INFINITY,Double.POSITIVE_INFINITY,1);
t.stop();
System.out.println("PVS: "+pv+", time: "+t.getTime());
t = new Timer().start();
double mm = minimax(parent,depth+1,true);
t.stop();
System.out.println("Minimax: "+mm+", time: "+t.getTime());
}
public static void createTree(Node n, int depth){
if(depth == 0) {
n.getChildren().add(new Node((double) randBetween(0, 10)));
return;
}
for (int i = 0; i < randBetween(2,10); i++) {
Node nn = new Node(0.);
n.getChildren().add(nn);
createTree(nn,depth-1);
}
}
private static Random r; // pseudo-random number generator
private static long seed; // pseudo-random number generator seed
// static initializer
static {
// this is how the seed was set in Java 1.4
seed = System.currentTimeMillis();
r = new Random(seed);
}
public static int randBetween(int min, int max){
return r.nextInt(max-min+1)+min;
}
public static double pvs(Node node, int depth, double alpha, double beta, int color){
if(depth == 0 || node.getChildren().isEmpty())
return color*node.getValue();
int i = 0;
double score;
for(Node child : node.getChildren()){
if(i++==0)
score = -pvs(child,depth-1,-beta,-alpha,-color);
else {
score = -pvs(child,depth-1,-alpha-1,-alpha,-color);
if(alpha<score || score<beta)
score = -pvs(child,depth-1,-beta,-score,-color);
}
alpha = Math.max(alpha,score);
if(alpha>=beta)
break;
}
return alpha;
}
public static double alphabeta(Node node, int depth, double alpha, double beta, boolean maximizingPlayer){
if(depth == 0 || node.getChildren().isEmpty())
return node.getValue();
double v = maximizingPlayer ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
for(Node child : node.getChildren()){
if(maximizingPlayer) {
v = Math.max(v, alphabeta(child, depth - 1, alpha, beta, false));
alpha = Math.max(alpha, v);
}else {
v = Math.min(v, alphabeta(child, depth - 1, alpha, beta, true));
beta = Math.min(beta, v);
}
if(beta <= alpha)
break;
}
return v;
}
public static double minimax(Node node, int depth, boolean maximizingPlayer){
if(depth == 0 || node.getChildren().isEmpty())
return node.getValue();
double v = maximizingPlayer ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
for(Node child : node.getChildren()){
if(maximizingPlayer)
v = Math.max(v,minimax(child,depth-1,false));
else
v = Math.min(v,minimax(child,depth-1,true));
}
return v;
}
static class Node{
List<Node> children = new ArrayList<>();
double value;
public Node(double value) {
this.value = value;
}
public List<Node> getChildren() {
return children;
}
public double getValue() {
return value;
}
public void setValue(double value) {
this.value = value;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Node node = (Node) o;
return Double.compare(node.value, value) == 0;
}
@Override
public int hashCode() {
return Objects.hash(value);
}
}
}
EDIT
Thanks to Jiri Tousek's comment it makes sense now: run with an odd depth, it usually gives a number higher than 5, and with an even depth a number usually lower than 5. With depth 11, for instance, it produces the following results:
Alpha Beta: 7.0, time: 39.697701 milli seconds
PVS: 7.0, time: 216.849568 milli seconds
Minimax: 7.0, time: 998.207216 milli seconds
Running it further with odd depths, the results I've seen are: depth 3 gives 5 to 10, depth 5 gives 7, 8 or 9, depth 7 gives 7 or 8, depth 9 gives 8, and depth 11 gives 7 or 8.
Even depths, by contrast, produce {2,4,(2,3,4,5,6)},{6,8,10,(2,3)}, i.e. depths 2 and 4 give 2 to 6, and depths 6, 8 and 10 give 2 or 3.
So when running the algorithm and looking at the results, should it matter whose turn it is?
I.e., if it's the maximizer's turn, search to an odd depth, and if it's the minimizer's turn, search to an even depth? Or does any depth produce the ideal move?
Since you're starting with maximum (at the very top) and going 10 levels down, you're gonna choose the minimums at the deepest level.
Since you're choosing minimum of the [2-10] numbers there (not an average), you cannot expect it to be roughly 5 - it will most often be lower. In fact, a chance that the minimum of 6 numbers in range [0-10] will be 5 or higher is roughly (1/2)^6 ~ 1.5%.
The operations then alternate. I'm not able to compute any good probabilities there, but intuitively the effect should be about neutral.
It might help to look at what happens after one alternation of min then max. Let's go with 6 children per node (to make it simple). As already stated, the chance that a result of the "min" phase will be 5 or higher is roughly (1/2)^6. For the "max" phase to result in a number 5 or higher, it's enough for any child to be 5 or higher. There are 6 children, so the chance is about 1 - (1 - (1/2)^6)^6 ~ 9%.
This is already much lower than the initial 50% chance of a 5+ number at the very bottom. Another iteration would then result in 1 - (1 - (0.09)^6)^6 ~ 3e-6 probability of a 5+ number.
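The arithmetic above can be checked with a few lines of code. This is only a sketch under the same simplifying assumptions as the answer (exactly 6 children per node, independent values, a 50% chance of a 5+ leaf); the class and method names are made up for illustration.

```java
public class MinMaxOdds {
    // P(min of k independent values >= t) = p^k, where p = P(one value >= t).
    static double afterMin(double p, int k) {
        return Math.pow(p, k);
    }

    // P(max of k independent values >= t) = 1 - (1 - p)^k.
    static double afterMax(double p, int k) {
        return 1 - Math.pow(1 - p, k);
    }

    public static void main(String[] args) {
        int k = 6;      // children per node (assumed)
        double p = 0.5; // chance that a single leaf is 5 or higher (assumed)
        for (int round = 1; round <= 2; round++) {
            p = afterMin(p, k); // minimizing level
            System.out.println("after min " + round + ": " + p);
            p = afterMax(p, k); // maximizing level
            System.out.println("after max " + round + ": " + p);
        }
    }
}
```

The first min/max round lands at about 9%, and the second at a few parts per million, in line with the figures quoted above.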

Java minimax abalone implementation

I want to implement minimax in my Abalone game, but I don't know how to do it.
To be exact, I don't know when the algorithm needs to maximize or minimize for a player.
If I have understood the logic, I need to minimize for the player and maximize for the AI?
This is the Wikipedia pseudocode:
function minimax(node, depth, maximizingPlayer)
    if depth = 0 or node is a terminal node
        return the heuristic value of node
    if maximizingPlayer
        bestValue := -∞
        for each child of node
            val := minimax(child, depth - 1, FALSE)
            bestValue := max(bestValue, val)
        return bestValue
    else
        bestValue := +∞
        for each child of node
            val := minimax(child, depth - 1, TRUE)
            bestValue := min(bestValue, val)
        return bestValue

(* Initial call for maximizing player *)
minimax(origin, depth, TRUE)
And my implementation
private Integer minimax(Board board, Integer depth, Color current, Boolean maximizingPlayer) {
Integer bestValue;
if (0 == depth)
return ((current == selfColor) ? 1 : -1) * this.evaluateBoard(board, current);
Integer val;
if (maximizingPlayer) {
bestValue = -INF;
for (Move m : board.getPossibleMoves(current)) {
board.apply(m);
val = minimax(board, depth - 1, current, Boolean.FALSE);
bestValue = Math.max(bestValue, val);
board.revert(m);
}
return bestValue;
} else {
bestValue = INF;
for (Move m : board.getPossibleMoves(current)) {
board.apply(m);
val = minimax(board, depth - 1, current, Boolean.TRUE);
bestValue = Math.min(bestValue, val);
board.revert(m);
}
return bestValue;
}
}
And my evaluate function
private Integer evaluateBoard(Board board, Color player) {
return board.ballsCount(player) - board.ballsCount(player.other());
}
It depends on your evaluation function; in your case, assuming the goal is to have more balls on the board than your opponent, the Player would be maximizing & the AI would be minimizing.
The usual method in two-player games is to always maximize, but negate the value as it is passed up to the parent.
Your evaluation function is not very useful for minimax search, as it would be constant for most moves of the game. The moves in Abalone are a lot less dramatic than in chess. Try using the sum of distances between all the player's marbles. This function gives minimax something to work with.
You also need to ensure that selfColor is the colour of the player to move when you make the initial call to minimax.
The recursion end could also be written
if (0 == depth)
return this.evaluateBoard(board, selfColor);
Out of scope of the question, but could be relevant to you: I find negamax easier to work with.
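For reference, a minimal negamax sketch: every level maximizes, and the child's score is negated on the way up because it is expressed from the opponent's point of view. To stay self-contained it uses a made-up toy game (take 1 to 3 stones, whoever takes the last stone wins) rather than Abalone; none of these names come from the question's code.

```java
public class NegamaxSketch {
    // Returns the score from the point of view of the player to move:
    // +1 = forced win, -1 = forced loss, 0 = unknown at the depth limit.
    static int negamax(int stones, int depth) {
        if (stones == 0) return -1; // opponent took the last stone: mover has lost
        if (depth == 0) return 0;   // depth limit reached
        int best = -Integer.MAX_VALUE;
        for (int take = 1; take <= Math.min(3, stones); take++) {
            // The child's score belongs to the opponent, so negate it.
            best = Math.max(best, -negamax(stones - take, depth - 1));
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(negamax(5, 10)); // prints 1: the mover can leave 4 stones
        System.out.println(negamax(4, 10)); // prints -1: every move leaves a winning pile
    }
}
```

In a real game the evaluation at the depth limit would also have to be returned relative to the player to move, which is what the multiplication by 1 or -1 in the question's code is for.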

Calculating direction based on point offsets

For my tile-based game, I need to calculate direction based on a given point offset (difference between two points). For example, let's say I'm standing at point (10, 4) and I want to move to point (8, 6). The direction I move at is north-west. What would be the best way to calculate this?
Here's my basic implementation in Java.
public int direction(int x, int y) {
if (x > 0) {
if (y > 0) {
return 0; // NE
} else if (y < 0) {
return 1; // SE
} else {
return 2; // E
}
} else if (x < 0) {
if (y > 0) {
return 3; // NW
} else if (y < 0) {
return 4; // SW
} else {
return 5; // W
}
} else {
if (y > 0) {
return 6; // N
} else if (y < 0) {
return 7; // S
} else {
return -1;
}
}
}
Surely it can be optimised or shortened. Any help? Thanks.
I think the easiest approach to understand would be a static array that contains the values for all cases.
// Won't say anything about how much these values make sense
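// Rows: sign of y (0: y < 0, 1: y == 0, 2: y > 0); columns: sign of x, same mapping.
// Values are the direction codes from the question (y > 0 means north).
static final int[][] directions = {
    {4, 7, 1},  // SW, S, SE
    {5, -1, 2}, // -1 for "no direction", feel free to replace
    {3, 6, 0}   // NW, N, NE
};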
public int direction(int x, int y) {
x = (x < 0) ? 0 : ((x > 0) ? 2 : 1);
y = (y < 0) ? 0 : ((y > 0) ? 2 : 1);
return directions[y][x];
}
Edit: Now it's correct (why are so many languages missing a proper sgn function?)
My answers with if conditions :).
public int direction(int x, int y) {
//0 NE, 1 SE, 2 E, 3 NW, 4 SW, 5 W, 6 N, 7 S, 8 (Same place / Not a direction)
int direction = 0;
if(x < 0){
direction = 3;
}else if(x == 0){
direction = 6;
}
if(y < 0){
direction = direction + 1;
}else if(y == 0){
direction = direction + 2;
}
return direction;
}
Define a 2D array to hold all states.
Convert x and y to 0, 1 or 2 based on their sign (x > 0, x < 0 or x == 0).
Return the array element at that index.
This is about as short and clean as you can get, if you represent the eight cardinal directions this way, as separate enumerated values. You're choosing between eight distinct return values, so a decision tree with eight leaves is the best you can do.
You might get something a little tidier if you split direction into two components (N-S and E-W), but without knowing more about what you do with direction, we can't know whether that's worth the trouble.
You can receive and return your direction as a Point or something similar (anyway, an (x,y) tuple). So if you're standing in p0 = (10, 4) and want to move to p1 = (8, 6), the result would be (in pseudocode):
norm(p1 - p0) = norm((-2,2)) = (-1,1)
You can calculate the norm of a nonzero integer by dividing it by its absolute value (use 0 for 0; in Java, Integer.signum does exactly this). So for a point you calculate the norm of both members. Just bear in mind that (-1,1) is more expressive than 3 and is easier to operate on.
If you need specific operations, you can create your own Java Point class or extend the existing ones in the library.
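A minimal sketch of this idea in Java, using Integer.signum for the per-component "norm" and a plain int array in place of a Point class (the step method and class name are hypothetical):

```java
public class StepDirection {
    // Returns the unit step {dx, dy} from (fromX, fromY) towards (toX, toY).
    // Each component is -1, 0 or 1; {0, 0} means "same place".
    static int[] step(int fromX, int fromY, int toX, int toY) {
        return new int[] {
            Integer.signum(toX - fromX),
            Integer.signum(toY - fromY)
        };
    }

    public static void main(String[] args) {
        int[] d = step(10, 4, 8, 6);
        System.out.println("(" + d[0] + ", " + d[1] + ")"); // prints (-1, 1)
    }
}
```

With the question's convention (y > 0 is north), (-1, 1) is the north-west step from (10, 4) towards (8, 6).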
