Package quickml.supervised.classifier.splitOnAttribute

Source Code of quickml.supervised.classifier.splitOnAttribute.SplitOnAttributeClassifierBuilder

package quickml.supervised.classifier.splitOnAttribute;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.AttributesMap;
import quickml.data.Instance;
import quickml.data.InstanceImpl;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.UpdatablePredictiveModelBuilder;
import quickml.supervised.classifier.decisionTree.tree.ClassificationCounter;

import java.io.Serializable;
import java.util.*;

/**
* Created by ian on 5/29/14.
*/
public class SplitOnAttributeClassifierBuilder implements UpdatablePredictiveModelBuilder<AttributesMap, SplitOnAttributeClassifier> {
    private static final  Logger logger =  LoggerFactory.getLogger(SplitOnAttributeClassifierBuilder.class);

    public static final Double NO_VALUE_PLACEHOLDER = Double.MIN_VALUE;

    private final String attributeKey;
    private final UpdatablePredictiveModelBuilder<AttributesMap, ? extends Classifier> wrappedBuilder;

    private final int minimumAmountTotalCrossData;
    private final double percentCrossData;
    private final Set<String> attributeWhiteList;
    private final int minimumAmountCrossDataPerClassification;

    public SplitOnAttributeClassifierBuilder(String attributeKey, UpdatablePredictiveModelBuilder<AttributesMap, ? extends Classifier> wrappedBuilder, int minimumAmountCrossData, double percentCrossData, Set<String> attributeWhiteList, int minimumAmountCrossDataPerClassification) {

        this.attributeKey = attributeKey;
        this.wrappedBuilder = wrappedBuilder;
        this.minimumAmountTotalCrossData = minimumAmountCrossData;
        this.percentCrossData = percentCrossData;
        this.attributeWhiteList = attributeWhiteList;
        this.minimumAmountCrossDataPerClassification = minimumAmountCrossDataPerClassification;
    }

    @Override
    public SplitOnAttributeClassifier buildPredictiveModel(final Iterable<? extends Instance<AttributesMap>> trainingData) {
        Map<Serializable, ArrayList<Instance<AttributesMap>>> splitTrainingData = splitTrainingData(trainingData);

        Map<Serializable, Classifier> splitModels = Maps.newHashMap();
        for (Map.Entry<Serializable, ArrayList<Instance<AttributesMap>>> trainingDataEntry : splitTrainingData.entrySet()) {
            logger.info("Building predictive model for "+attributeKey+"="+trainingDataEntry.getKey());
            setID(trainingDataEntry.getKey());
            splitModels.put(trainingDataEntry.getKey(), wrappedBuilder.buildPredictiveModel(trainingDataEntry.getValue()));
        }

        logger.info("Building default predictive model");
        setID(null);
        final Classifier defaultPM = wrappedBuilder.buildPredictiveModel(trainingData);
        return new SplitOnAttributeClassifier(attributeKey, splitModels, defaultPM);
    }

    @Override
    public PredictiveModelBuilder<AttributesMap, SplitOnAttributeClassifier> updatable(boolean updatable) {
        this.wrappedBuilder.updatable(updatable);
        return this;
    }

    private Map<Serializable, ArrayList<Instance<AttributesMap>>> splitTrainingData(Iterable<? extends Instance<AttributesMap>> trainingData) {
        Map<Serializable, ArrayList<Instance<AttributesMap>>> splitTrainingData = Maps.newHashMap();
        ArrayList<Instance<AttributesMap>> allData = new ArrayList<>();
        for (Instance<AttributesMap> instance : trainingData) {
            Serializable value = instance.getAttributes().get(attributeKey);
            if (value == null) value = NO_VALUE_PLACEHOLDER;
            ArrayList<Instance<AttributesMap>> splitData = splitTrainingData.get(value);
            if (splitData == null) {
                splitData = Lists.newArrayList();
                splitTrainingData.put(value, splitData);
            }
            splitData.add(instance);
            allData.add(instance);
        }

        crossPollinateData(splitTrainingData, allData);
        return splitTrainingData;
    }

    /*
    * Add data to each split data set based on the desired cross data values. Maintain the same ratio of classifications in the split set by
    * selecting that ratio from outside sets. Only keep the attributes in the supporting instances that in in the white list
    * */
    private void crossPollinateData(Map<Serializable, ArrayList<Instance<AttributesMap>>> splitTrainingData, ArrayList<Instance<AttributesMap>> allData) {
        for(Map.Entry<Serializable, ArrayList<Instance<AttributesMap>>> entry : splitTrainingData.entrySet()) {
            ClassificationCounter splitClassificationCounter = ClassificationCounter.countAll(entry.getValue());
            long amountCrossData = (long) Math.max(splitClassificationCounter.getTotal() * percentCrossData, minimumAmountTotalCrossData);
            Set<Instance<AttributesMap>> crossData = new HashSet<>();
            ClassificationCounter crossDataCount = new ClassificationCounter();
            for(int i = allData.size()-1; i >= 0; i--) {
                Instance<AttributesMap>instance = allData.get(i);
                double classificationRatio = splitClassificationCounter.getCount(instance.getLabel()) / splitClassificationCounter.getTotal();
                double targetCount = Math.max(classificationRatio * amountCrossData, minimumAmountCrossDataPerClassification);
                if(shouldAddInstance(entry.getKey(), instance, crossDataCount, targetCount)) {
                    crossData.add(cleanSupportingData(instance));
                    crossDataCount.addClassification(instance.getLabel(), instance.getWeight());
                }
                if(crossDataCount.getTotal() >= amountCrossData) {
                    break;
                }
            }
            //cross pollinate data
            entry.getValue().addAll(crossData);
        }
    }

    /*
     * Add instances such that the ratio of classifications is unchanged
    * */
    private boolean shouldAddInstance(Serializable attributeValue, Instance<AttributesMap> instance, ClassificationCounter crossDataCount, double targetCount) {
        if (!attributeValue.equals(instance.getAttributes().get(attributeKey))) {
            if (targetCount > crossDataCount.getCount(instance.getLabel())) {
                return true;
            }
        }
        return false;
    }

    private Instance<AttributesMap>cleanSupportingData(Instance<AttributesMap> instance) {
        AttributesMap attributes = AttributesMap.newHashMap();
        for (String key : instance.getAttributes().keySet()) {
            if (attributeWhiteList.isEmpty() || attributeWhiteList.contains(key)) {
                attributes.put(key, instance.getAttributes().get(key));
            }
        }
        return new InstanceImpl(attributes, instance.getLabel(), instance.getWeight());
    }

    @Override
    public void setID(Serializable id) {
        wrappedBuilder.setID(id);
    }

    @Override
    public void updatePredictiveModel(SplitOnAttributeClassifier predictiveModel, Iterable<? extends Instance<AttributesMap>> newData, boolean splitNodes) {
        if (wrappedBuilder instanceof UpdatablePredictiveModelBuilder) {
            Map<Serializable, ArrayList<Instance<AttributesMap>>> splitNewData = splitTrainingData(newData);
            for (Map.Entry<Serializable, ArrayList<Instance<AttributesMap>>> newDataEntry : splitNewData.entrySet()) {
                Classifier pm = predictiveModel.getSplitModels().get(newDataEntry.getKey());
                if(pm == null) {
                    logger.info("Building predictive model for "+attributeKey+"="+newDataEntry.getKey());
                    setID(newDataEntry.getKey());
                    pm = wrappedBuilder.buildPredictiveModel(newDataEntry.getValue());
                    predictiveModel.getSplitModels().put(newDataEntry.getKey(), pm);
                } else {
                    logger.info("Updating predictive model for "+attributeKey+"="+newDataEntry.getKey());
                    ((UpdatablePredictiveModelBuilder) wrappedBuilder).updatePredictiveModel(pm, newDataEntry.getValue(), splitNodes);
                }
            }
            logger.info("Updating default predictive model");
            setID(null);
            ((UpdatablePredictiveModelBuilder) wrappedBuilder).updatePredictiveModel(predictiveModel.getDefaultPM(), newData, splitNodes);
        } else {
            throw new RuntimeException("Cannot update predictive model without UpdatablePredictiveModelBuilder");
        }
    }

    @Override
    public void stripData(SplitOnAttributeClassifier predictiveModel) {
        if (wrappedBuilder instanceof UpdatablePredictiveModelBuilder) {
            for(Classifier pm : predictiveModel.getSplitModels().values()) {
                ((UpdatablePredictiveModelBuilder) wrappedBuilder).stripData(pm);
            }
            ((UpdatablePredictiveModelBuilder) wrappedBuilder).stripData(predictiveModel.getDefaultPM());
        } else {
            throw new RuntimeException("Cannot strip data without UpdatablePredictiveModelBuilder");
        }
    }
}
TOP

Related Classes of quickml.supervised.classifier.splitOnAttribute.SplitOnAttributeClassifierBuilder

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.