package org.cspoker.ai.opponentmodels.weka;
import java.io.IOException;
import java.util.HashMap;
import org.apache.log4j.Logger;
import org.cspoker.client.common.gamestate.GameState;
import org.cspoker.client.common.gamestate.modifiers.AllInState;
import org.cspoker.client.common.gamestate.modifiers.BetState;
import org.cspoker.client.common.gamestate.modifiers.CallState;
import org.cspoker.client.common.gamestate.modifiers.CheckState;
import org.cspoker.client.common.gamestate.modifiers.FoldState;
import org.cspoker.client.common.gamestate.modifiers.RaiseState;
import org.cspoker.common.elements.player.PlayerId;
import org.cspoker.common.util.Util;
import org.cspoker.ai.bots.bot.gametree.action.BetAction;
import org.cspoker.ai.bots.bot.gametree.action.CallAction;
import org.cspoker.ai.bots.bot.gametree.action.CheckAction;
import org.cspoker.ai.bots.bot.gametree.action.FoldAction;
import org.cspoker.ai.bots.bot.gametree.action.RaiseAction;
import org.cspoker.ai.bots.bot.gametree.action.SearchBotAction;
import org.cspoker.ai.bots.bot.gametree.mcts.nodes.INode;
import org.cspoker.ai.bots.bot.gametree.mcts.nodes.InnerNode;
import org.cspoker.ai.opponentmodels.OpponentModel;
import org.cspoker.ai.opponentmodels.weka.ARFFPropositionalizer;
import org.cspoker.ai.opponentmodels.weka.PlayerTrackingVisitor;
import com.google.common.collect.ImmutableList;
/**
* The ActionTrackingVisitor currently is used to observe the game
* and delegate important states to an {@link ARFFPropositionalizer}<br>
*/
public class ActionTrackingVisitor extends PlayerTrackingVisitor {
private final static Logger logger = Logger.getLogger(ARFFPropositionalizer.class);
private class AccuracyData {
double truePositive = 0.0;
double trueNegative = 0.0;
double falsePositive = 0.0;
double falseNegative = 0.0;
}
private HashMap<PlayerId, AccuracyData> accuracyData;
public ActionTrackingVisitor(OpponentModel opponentModel, PlayerId bot) {
super(opponentModel);
try {
this.propz = new ARFFPropositionalizer(bot);
accuracyData = new HashMap<PlayerId, AccuracyData>();
} catch (IOException e) {
e.printStackTrace();
}
}
public ARFFPropositionalizer getPropz() {
return (ARFFPropositionalizer) this.propz;
}
private InnerNode getNode(GameState state) {
try {
return (InnerNode) parentOpponentModel.getChosenNode();
} catch (ClassCastException e) {
return null;
}
}
public void printAccuracy() {
for (PlayerId id : accuracyData.keySet()) {
AccuracyData data = accuracyData.get(id);
System.out.print(id + "\t" + //" Accuracy : " +
(data.trueNegative + data.truePositive) /
(data.trueNegative + data.truePositive + data.falseNegative + data.falsePositive));
System.out.print("\t");
}
System.out.println("");
}
private Prediction getProbability(GameState gameState) {
return getProbability(gameState, 0);
}
/**
* To calculate the accuracy of the opponentmodel we need the probabilities
* of considered actions of opponents when the {@link MCTSBot} calculates
* the best action for him to take. For this we use the {@link INode} that
* contains the action made by the bot. If it is an {@link InnerNode} the
* children contain the probabilities of actions the opponent could make.
* Probalities of raises are grouped to give a probability for the action
* Raise. To consider consecutive actions by opponents we adjust the
* {@link INode} to become the correct child, which is the actual action
* made by the opponent. In case of a raise, the raise amongst the children
* nearest to the actual one is chosen.
* The probabilities are returned as a {@link Prediction}.
*
* @param gameState
* the {@link GameState} for which we want to calculate the
* probability
* @param raiseAmount
* the bet/raise amount, otherwise 0
* @return a {@link Prediction} of the considered action.
*
* <br>
* <br>
* TODO: grouping probalities of raises could be improved.
*/
private Prediction getProbability(GameState gameState, double raiseAmount) {
// This method should only be called after MCTSBot has acted
if (parentOpponentModel.getChosenNode() == null)
return null;
HashMap<Class<?>, SearchBotAction> actions = new HashMap<Class<?>, SearchBotAction>();
HashMap<Class<?>, Double> probs = new HashMap<Class<?>, Double>();
Class<?> cProb = null;
RaiseAction raiseAction = null;
BetAction betAction = null;
String errorStr = "";
InnerNode node = getNode(gameState);
if (node != null) {
errorStr = (">-----------------------------");
errorStr += ("\n" + getPlayerName(gameState) + " State " + gameState.getClass());
ImmutableList<INode> children = node.getChildren();
if (children != null) {
for (INode n : children) {
Class<?> c = n.getLastAction().getAction().getClass();
// Same actions are grouped to make one probability (bet/raise)
if (!probs.containsKey(c))
probs.put(c, n.getLastAction().getProbability());
else
probs.put(c, n.getLastAction().getProbability() + probs.get(c));
actions.put(c, n.getLastAction().getAction());
if (gameState.getClass().equals(
n.getLastAction().getAction()
.getUnwrappedStateAfterAction().getClass()) ||
// TODO: you shouldn't get BetAction in RaiseState (but it does happen somehow...)
(gameState.getClass().equals(RaiseState.class) &&
n.getLastAction().getAction().getClass().equals(BetAction.class))) {// ||
// // TODO: idem with Raise-/BetAction in AllinState (now this situation is ignored)
// (gameState.getClass().equals(AllInState.class) &&
// n.getLastAction().getAction().getClass().equals(BetAction.class))) {
if (cProb == null) {
errorStr += "\n Setting chosen node with action " + n.getLastAction().getAction();
parentOpponentModel.setChosenNode(n);
}
cProb = c;
if (raiseAction == null && c.equals(RaiseAction.class))
raiseAction = (RaiseAction) n.getLastAction().getAction();
else if (betAction == null && c.equals(BetAction.class))
betAction = (BetAction) n.getLastAction().getAction();
}
// Correct child node is chosen for bet/raise
if (cProb != null) {
if (raiseAction != null && c.equals(RaiseAction.class)) {
RaiseAction newRaiseAction = (RaiseAction) n.getLastAction().getAction();
if (Math.abs(newRaiseAction.amount - raiseAmount) <
Math.abs(raiseAction.amount - raiseAmount)) {
raiseAction = newRaiseAction;
errorStr += "\n Setting chosen node with action " + n.getLastAction().getAction();
parentOpponentModel.setChosenNode(n);
}
}
else if (betAction != null && c.equals(BetAction.class)) {
BetAction newBetAction = (BetAction) n.getLastAction().getAction();
if (Math.abs(newBetAction.amount - raiseAmount) <
Math.abs(betAction.amount - raiseAmount)) {
betAction = newBetAction;
errorStr += "\n Setting chosen node with action " + n.getLastAction().getAction();
parentOpponentModel.setChosenNode(n);
}
}
}
errorStr += ("\nState "
+ n.getLastAction().getAction().getUnwrappedStateAfterAction().getClass()
+ " with action "
+ n.getLastAction().getAction()
+ "\t with probability "
+ (double) Math.round(n.getLastAction().getProbability() * 10000) / 100
+ "% and totalProb "
+ (double) Math.round(probs.get(c) * 10000) / 100 + "%");
}
errorStr += ("\n> Chosen child with action " +
parentOpponentModel.getChosenNode().getLastAction().getAction());
} else {
errorStr += ("\nNo children for node with action " +
node.getLastAction().getAction());
}
errorStr += ("\n-----------------------------<");
}
// chosen node of opponentmodel should have changed
SearchBotAction action = parentOpponentModel.getChosenNode().getLastAction().getAction();
if (parentOpponentModel.getChosenNode() == node || cProb == null) {
// System.err.println(str);
return null;
}
// System.out.println(">----------------------------");
// for (Class<?> c : probs.keySet()) {
// if (c.equals(cProb))
// assimilatePrediction(new Prediction(action, 1, probs.get(cProb)));
// else
// assimilatePrediction(new Prediction(actions.get(c), 0, probs.get(c)));
// }
// System.out.println("-----------------------------<");
return new Prediction(action, 1, probs.get(cProb));
}
private void assimilatePrediction(PlayerId id, Prediction p) {
if (p == null || p.getAction() == null) return;
// System.out.println(p + ", TP: " + p.getTruePositive() + ", TN: " + p.getTrueNegative()
// + ", FP: " + p.getFalsePositive() + ", FN: " + p.getFalseNegative());
if (!accuracyData.containsKey(id))
accuracyData.put(id, new AccuracyData());
AccuracyData data = accuracyData.get(id);
data.truePositive += p.getTruePositive();
data.trueNegative += p.getTrueNegative();
data.falsePositive += p.getFalsePositive();
data.falseNegative += p.getFalseNegative();
// printAccuracy();
}
public double getAccuracy(PlayerId id) {
AccuracyData data = accuracyData.get(id);
if (data == null)
return 0.0;
else
return (data.trueNegative + data.truePositive) /
(data.trueNegative + data.truePositive + data.falseNegative + data.falsePositive);
}
@Override
public void visitCallState(CallState callState) {
InnerNode node = getNode(callState);
if (node != null && !callState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(callState);
assimilatePrediction(callState.getNextToAct(), p);
getPropz().logCallProb(callState.getNextToAct(), p);
logger.trace(getPlayerName(callState) + " " + p);
} else {
logger.trace(getPlayerName(callState) + " CallState");
}
propz.signalCall(false, callState.getEvent().getPlayerId());
}
@Override
public void visitRaiseState(RaiseState raiseState) {
InnerNode node = getNode(raiseState);
if (node != null && !raiseState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(raiseState,raiseState.getLargestBet());
assimilatePrediction(raiseState.getNextToAct(), p);
getPropz().logRaiseProb(raiseState.getNextToAct(), p);
logger.trace(getPlayerName(raiseState) +
" Raise " + Util.parseDollars(raiseState.getLargestBet()) +
" - with <" + p + ">");
} else {
logger.trace(getPlayerName(raiseState) + " RaiseState: " + Util.parseDollars(raiseState.getLargestBet()));
}
propz.signalRaise(false, raiseState.getLastEvent().getPlayerId(), raiseState.getLargestBet());
}
@Override
public void visitFoldState(FoldState foldState) {
InnerNode node = getNode(foldState);
if (node != null && !foldState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(foldState);
assimilatePrediction(foldState.getNextToAct(), p);
getPropz().logFoldProb(foldState.getNextToAct(), p);
logger.trace(getPlayerName(foldState) + " " + p);
} else {
logger.trace(getPlayerName(foldState) + " FoldState");
}
propz.signalFold(foldState.getEvent().getPlayerId());
}
@Override
public void visitCheckState(CheckState checkState) {
InnerNode node = getNode(checkState);
if (node != null && !checkState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(checkState);
assimilatePrediction(checkState.getNextToAct(), p);
getPropz().logCheckProb(checkState.getNextToAct(), p);
logger.trace(getPlayerName(checkState) + " " + p);
} else {
logger.trace(getPlayerName(checkState) + " CheckState");
}
propz.signalCheck(checkState.getEvent().getPlayerId());
}
@Override
public void visitBetState(BetState betState) {
InnerNode node = getNode(betState);
if (node != null && !betState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(betState, betState.getEvent().getAmount());
assimilatePrediction(betState.getNextToAct(), p);
getPropz().logBetProb(betState.getNextToAct(), p);
logger.trace(getPlayerName(betState) +
" Bet " + Util.parseDollars(betState.getEvent().getAmount()) +
" - with <" + p + ">");
} else {
logger.trace(getPlayerName(betState) + " BetState: " + Util.parseDollars(betState.getEvent().getAmount()));
}
propz.signalBet(false, betState.getEvent().getPlayerId(), betState.getEvent().getAmount());
}
@Override
public void visitAllInState(AllInState allInState) {
InnerNode node = getNode(allInState);
if (node != null && !allInState.getNextToAct().equals(parentOpponentModel.getBotId())) {
Prediction p = getProbability(allInState, allInState.getEvent().getMovedAmount());
assimilatePrediction(allInState.getNextToAct(), p);
if (p != null) {
if (p.getAction() instanceof CallAction)
getPropz().logCallProb(allInState.getNextToAct(), p);
if (p.getAction() instanceof FoldAction)
getPropz().logFoldProb(allInState.getNextToAct(), p);
if (p.getAction() instanceof RaiseAction)
getPropz().logRaiseProb(allInState.getNextToAct(), p);
if (p.getAction() instanceof CheckAction)
getPropz().logCheckProb(allInState.getNextToAct(), p);
if (p.getAction() instanceof BetAction)
getPropz().logBetProb(allInState.getNextToAct(), p);
}
logger.trace(getPlayerName(allInState) +
" All-in " + Util.parseDollars(allInState.getEvent().getMovedAmount()) +
" - with <" + p + ">");
} else {
logger.trace(getPlayerName(allInState) + " AllInState");
}
propz.signalAllIn(allInState.getEvent().getPlayerId(), allInState.getEvent().getMovedAmount());
}
}