Package quickml.supervised.classifier.decisionTree.scorers

Source Code of quickml.supervised.classifier.decisionTree.scorers.MSEScorer

package quickml.supervised.classifier.decisionTree.scorers;

import quickml.supervised.classifier.decisionTree.Scorer;
import quickml.supervised.classifier.decisionTree.tree.ClassificationCounter;

import java.io.Serializable;
import java.util.Map;

/**
* A Scorer intended to estimate the impact on the Mean of the Squared Error (MSE)
* of a branch existing versus not existing.  The value returned is the MSE
* without the branch minus the MSE with the branch (so higher is better, as
* is required by the scoreSplit() interface.
*/
public class MSEScorer implements Scorer {
    private final double crossValidationInstanceCorrection;

    public MSEScorer(CrossValidationCorrection crossValidationCorrection) {
        if (crossValidationCorrection.equals(CrossValidationCorrection.TRUE)) {
            crossValidationInstanceCorrection = 1.0;
        } else {
            crossValidationInstanceCorrection = 0.0;
        }
    }

    @Override
    public double scoreSplit(final ClassificationCounter a, final ClassificationCounter b) {
        ClassificationCounter parent = ClassificationCounter.merge(a, b);
        double parentMSE = getTotalError(parent) / parent.getTotal();
        double splitMSE = (getTotalError(a) + getTotalError(b)) / (a.getTotal() + b.getTotal());
        return parentMSE - splitMSE;
    }

    private double getTotalError(ClassificationCounter cc) {
        double totalError = 0;
        for (Map.Entry<Serializable, Double> e : cc.getCounts().entrySet()) {
            double error = (cc.getTotal()>0) ? 1.0 - e.getValue()/cc.getTotal() : 0;
            double errorSquared = error*error;
            totalError += errorSquared * e.getValue();
        }
        return totalError;
    }

    public enum CrossValidationCorrection {
        TRUE, FALSE
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder("MSEScorer{");
        sb.append("cvic=").append(crossValidationInstanceCorrection);
        sb.append('}');
        return sb.toString();
    }
}
TOP

Related Classes of quickml.supervised.classifier.decisionTree.scorers.MSEScorer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.