// Package org.apache.ctakes.ytex.kernel
//
// Source Code of org.apache.ctakes.ytex.kernel.ImputedFeatureEvaluatorImpl$Parameters

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.ytex.kernel;

import java.io.IOException;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

import javax.sql.DataSource;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.dao.ConceptDao;
import org.apache.ctakes.ytex.kernel.model.ConcRel;
import org.apache.ctakes.ytex.kernel.model.ConceptGraph;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.FeatureEvaluation;
import org.apache.ctakes.ytex.kernel.model.FeatureParentChild;
import org.apache.ctakes.ytex.kernel.model.FeatureRank;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.transaction.PlatformTransactionManager;

import weka.core.ContingencyTables;

/**
* Calculate the mutual information of each concept of a corpus wrt a concept
* graph and classification task (label) and possibly a fold. We calculate the
* following:
* <ul>
* <li>raw mutual information of each concept (infogain). We calculate the joint
* distribution of concepts (X) and document classes (Y), and compute the mutual
* information for each concept.
* <li>mutual information inherited by parents (infogain-parent). For each
* concept in the concept graph, we merge the joint distribution of child
* concepts. This is done recursively.
* <li>mutual information inherited by children from parents (infogain-child).
* We take the top n concepts and assign their children (entire subgraph) the
* mutual info of the parent.
* </ul>
* <p>
* The mutual information of each concept is stored in the feature_rank table.
* The related records in the feature_eval table have the following values:
* <ul>
* <li>type = infogain, infogain-parent, infogain-imputed, infogain-imputed-filt
* <li>feature_set_name = conceptSetName
* <li>param1 = conceptGraphName
* </ul>
*
* How this works in broad strokes:
* <ul>
* <li> {@link #evaluateCorpus(Parameters)} load instances, iterate through
* labels
* <li>
* {@link #evaluateCorpusLabel(Parameters, ConceptGraph, InstanceData, String)}
* load concept - set[document] map for the specified label, iterate through
* folds
* <li>
* {@link #evaluateCorpusFold(Parameters, Map, ConceptGraph, InstanceData, String, Map, int)}
* create raw joint distribution of each concept, compute parent joint
* distributions, assign children mutual info of parents
* <li> {@link #completeJointDistroForFold(Map, Map, Set, Set, String)} computes
* raw joint distribution of each concept
* <li>
* {@link #propagateJointDistribution(Map, Parameters, String, int, ConceptGraph, Map)}
* recursively compute parent joint distribution by merging joint distro of
* children.
* <li>
* {@link #storeChildConcepts(List, Parameters, String, int, ConceptGraph, boolean)}
* takes the top ranked parent concepts and assigns the concepts in their
* subtrees the mutual info of the parents. Only concepts that exist in the
* corpus are added (depends on computing the infocontent of concepts with
* CorpusEvaluator)
* </ul>
*
*
* @author vijay
*
*/
public class ImputedFeatureEvaluatorImpl implements ImputedFeatureEvaluator {

  /**
   * fill in map of Concept Id - bin - instance ids
   *
   * @author vijay
   *
   */
  public class ConceptInstanceMapExtractor implements RowCallbackHandler {
    ConceptGraph cg;
    Map<String, Map<String, Set<Long>>> conceptInstanceMap;

    ConceptInstanceMapExtractor(
        Map<String, Map<String, Set<Long>>> conceptInstanceMap,
        ConceptGraph cg) {
      this.cg = cg;
      this.conceptInstanceMap = conceptInstanceMap;
    }

    public void processRow(ResultSet rs) throws SQLException {
      String conceptId = rs.getString(1);
      long instanceId = rs.getLong(2);
      String x = rs.getString(3);
      Map<String, Set<Long>> binInstanceMap = conceptInstanceMap
          .get(conceptId);
      if (binInstanceMap == null) {
        // use the conceptId from the concept to save memory
        binInstanceMap = new HashMap<String, Set<Long>>(2);
        conceptInstanceMap.put(conceptId, binInstanceMap);
      }
      Set<Long> instanceIds = binInstanceMap.get(x);
      if (instanceIds == null) {
        instanceIds = new HashSet<Long>();
        binInstanceMap.put(x, instanceIds);
      }
      instanceIds.add(instanceId);
    }
  }

  /**
   * joint distribution of concept (x) and class (y). The bins for x and y are
   * predetermined. Typical levels for x are 0/1 (absent/present) and -1/0/1
   * (negated/not present/affirmed).
   *
   * @author vijay
   *
   */
  public static class JointDistribution {
    /**
     * merge joint distributions into a single distribution. For each value
     * of Y, the cells for each X bin, except for the xMerge bin, are the
     * intersection of all the instances in each of the corresponding bins.
     * The xMerge bin gets everything that is leftover.
     *
     * @param jointDistros
     *            list of joint distribution tables to merge
     * @param yMargin
     *            map of y val - instance id. this could be calculated on
     *            the fly, but we have this information already.
     * @param xMerge
     *            the x val that contains everything that doesn't land in
     *            any of the other bins.
     * @return
     */
    public static JointDistribution merge(
        List<JointDistribution> jointDistros,
        Map<String, Set<Long>> yMargin, String xMerge) {
      Set<String> xVals = jointDistros.get(0).xVals;
      Set<String> yVals = jointDistros.get(0).yVals;
      JointDistribution mergedDistro = new JointDistribution(xVals, yVals);
      for (String y : yVals) {
        // intersect all bins besides the merge bin
        Set<Long> xMergedInst = mergedDistro.getInstances(xMerge, y);
        // everything comes into the merge bin
        // we take out things that land in other bins
        xMergedInst.addAll(yMargin.get(y));
        // iterate over other bins
        for (String x : xVals) {
          if (!x.equals(xMerge)) {
            Set<Long> intersectIds = mergedDistro
                .getInstances(x, y);
            boolean bFirstIter = true;
            // iterate over all joint distribution tables
            for (JointDistribution distro : jointDistros) {
              if (bFirstIter) {
                // 1st iter - add all
                intersectIds.addAll(distro.getInstances(x, y));
                bFirstIter = false;
              } else {
                // subsequent iteration - intersect
                intersectIds.retainAll(distro
                    .getInstances(x, y));
              }
            }
            // remove from the merge bin
            xMergedInst.removeAll(intersectIds);
          }
        }
      }
      return mergedDistro;
    }

    protected double[][] contingencyTable;
    /**
     * the entropy of X. Calculated once and returned as needed.
     */
    protected Double entropyX = null;
    /**
     * the entropy of X*Y. Calculated once and returned as needed.
     */
    protected Double entropyXY = null;
    /**
     * A y*x table where the cells hold the instance ids. We use the
     * instance ids instead of counts so we can merge the tables.
     */
    protected SortedMap<String, SortedMap<String, Set<Long>>> jointDistroTable;
    /**
     * the possible values of X (e.g. concept)
     */
    protected Set<String> xVals;

    /**
     * the possible values of Y (e.g. text)
     */
    protected Set<String> yVals;

    /**
     * set up the joint distribution table.
     *
     * @param xVals
     *            the possible x values (bins)
     * @param yVals
     *            the possible y values (bins)
     */
    public JointDistribution(Set<String> xVals, Set<String> yVals) {
      this.xVals = xVals;
      this.yVals = yVals;
      jointDistroTable = new TreeMap<String, SortedMap<String, Set<Long>>>();
      for (String yVal : yVals) {
        SortedMap<String, Set<Long>> yMap = new TreeMap<String, Set<Long>>();
        jointDistroTable.put(yVal, yMap);
        for (String xVal : xVals) {
          yMap.put(xVal, new HashSet<Long>());
        }
      }
    }

    public JointDistribution(Set<String> xVals, Set<String> yVals,
        Map<String, Set<Long>> xMargin, Map<String, Set<Long>> yMargin,
        String xLeftover) {
      this.xVals = xVals;
      this.yVals = yVals;
      jointDistroTable = new TreeMap<String, SortedMap<String, Set<Long>>>();
      for (String yVal : yVals) {
        SortedMap<String, Set<Long>> yMap = new TreeMap<String, Set<Long>>();
        jointDistroTable.put(yVal, yMap);
        for (String xVal : xVals) {
          yMap.put(xVal, new HashSet<Long>());
        }
      }
      for (Map.Entry<String, Set<Long>> yEntry : yMargin.entrySet()) {
        // iterate over 'rows' i.e. the class names
        String yName = yEntry.getKey();
        Set<Long> yInst = new HashSet<Long>(yEntry.getValue());
        // iterate over 'columns' i.e. the values of x
        for (Map.Entry<String, Set<Long>> xEntry : xMargin.entrySet()) {
          // copy the instances
          Set<Long> foldXInst = jointDistroTable.get(yName).get(
              xEntry.getKey());
          foldXInst.addAll(xEntry.getValue());
          // keep only the ones that are in this fold
          foldXInst.retainAll(yInst);
          // remove the instances for this value of x from the set of
          // all instances
          yInst.removeAll(foldXInst);
        }
        if (yInst.size() > 0) {
          // add the leftovers to the leftover bin
          jointDistroTable.get(yEntry.getKey()).get(xLeftover)
              .addAll(yInst);
        }
      }

    }

    // /**
    // * add an instance to the joint probability table
    // *
    // * @param x
    // * @param y
    // * @param instanceId
    // */
    // public void addInstance(String x, String y, int instanceId) {
    // // add the current row to the bin matrix
    // SortedMap<String, Set<Integer>> xMap = jointDistroTable.get(y);
    // if (xMap == null) {
    // xMap = new TreeMap<String, Set<Integer>>();
    // jointDistroTable.put(y, xMap);
    // }
    // Set<Integer> instanceSet = xMap.get(x);
    // if (instanceSet == null) {
    // instanceSet = new HashSet<Integer>();
    // xMap.put(x, instanceSet);
    // }
    // instanceSet.add(instanceId);
    // }

    // /**
    // * finalize the joint probability table wrt the specified instances.
    // If
    // * we are doing this per fold, then not all instances are going to be
    // in
    // * each fold. Limit to the instances in the specified fold.
    // * <p>
    // * Also, we might not have filled in all the cells. if necessary, put
    // * instances in the 'leftover' cell, fill it in based on the marginal
    // * distribution of the instances wrt classes.
    // *
    // * @param yMargin
    // * map of values of y to the instances with that value
    // * @param xLeftover
    // * the value of x to assign the the leftover instances
    // */
    // public JointDistribution complete(Map<String, Set<Integer>> xMargin,
    // Map<String, Set<Integer>> yMargin, String xLeftover) {
    // JointDistribution foldDistro = new JointDistribution(this.xVals,
    // this.yVals);
    // for (Map.Entry<String, Set<Integer>> yEntry : yMargin.entrySet()) {
    // // iterate over 'rows' i.e. the class names
    // String yName = yEntry.getKey();
    // Set<Integer> yInst = new HashSet<Integer>(yEntry.getValue());
    // // iterate over 'columns' i.e. the values of x
    // for (Map.Entry<String, Set<Integer>> xEntry : this.jointDistroTable
    // .get(yName).entrySet()) {
    // // copy the instances
    // Set<Integer> foldXInst = foldDistro.jointDistroTable.get(
    // yName).get(xEntry.getKey());
    // foldXInst.addAll(xEntry.getValue());
    // // keep only the ones that are in this fold
    // foldXInst.retainAll(yInst);
    // // remove the instances for this value of x from the set of
    // // all instances
    // yInst.removeAll(foldXInst);
    // }
    // if (yInst.size() > 0) {
    // // add the leftovers to the leftover bin
    // foldDistro.jointDistroTable.get(yEntry.getKey())
    // .get(xLeftover).addAll(yInst);
    // }
    // }
    // return foldDistro;
    // }

    public double[][] getContingencyTable() {
      if (contingencyTable == null) {
        contingencyTable = new double[this.yVals.size()][this.xVals
            .size()];
        int i = 0;
        for (String yVal : yVals) {
          int j = 0;
          for (String xVal : xVals) {
            contingencyTable[i][j] = jointDistroTable.get(yVal)
                .get(xVal).size();
            j++;
          }
          i++;
        }
      }
      return contingencyTable;
    }

    public double getEntropyX() {
      double probs[] = new double[xVals.size()];
      Arrays.fill(probs, 0d);
      if (entropyX == null) {
        double nTotal = 0;
        for (Map<String, Set<Long>> xInstance : this.jointDistroTable
            .values()) {
          int i = 0;
          for (Set<Long> instances : xInstance.values()) {
            double nCell = (double) instances.size();
            nTotal += nCell;
            probs[i] += nCell;
            i++;
          }
        }
        for (int i = 0; i < probs.length; i++)
          probs[i] /= nTotal;
        entropyX = entropy(probs);
      }
      return entropyX;
    }

    public double getEntropyXY() {
      double probs[] = new double[xVals.size() * yVals.size()];
      Arrays.fill(probs, 0d);
      if (entropyXY == null) {
        double nTotal = 0;
        int i = 0;
        for (Map<String, Set<Long>> xInstance : this.jointDistroTable
            .values()) {
          for (Set<Long> instances : xInstance.values()) {
            probs[i] = (double) instances.size();
            nTotal += probs[i];
            i++;
          }
        }
        for (int j = 0; j < probs.length; j++)
          probs[j] /= nTotal;
        entropyXY = entropy(probs);
      }
      return entropyXY;
    }

    public double getInfoGain() {
      return ContingencyTables.entropyOverColumns(getContingencyTable())
          - ContingencyTables
              .entropyConditionedOnRows(getContingencyTable());
    }

    public Set<Long> getInstances(String x, String y) {
      return jointDistroTable.get(y).get(x);
    }

    public double getMutualInformation(double entropyY) {
      return entropyY + this.getEntropyX() - this.getEntropyXY();
    }

    /**
     * print out joint distribution table
     */
    public String toString() {
      StringBuilder b = new StringBuilder();
      b.append(this.getClass().getCanonicalName());
      b.append(" [jointDistro=(");
      Iterator<Entry<String, SortedMap<String, Set<Long>>>> yIter = this.jointDistroTable
          .entrySet().iterator();
      while (yIter.hasNext()) {
        Entry<String, SortedMap<String, Set<Long>>> yEntry = yIter
            .next();
        Iterator<Entry<String, Set<Long>>> xIter = yEntry.getValue()
            .entrySet().iterator();
        while (xIter.hasNext()) {
          Entry<String, Set<Long>> xEntry = xIter.next();
          b.append(xEntry.getValue().size());
          if (xIter.hasNext())
            b.append(", ");
        }
        if (yIter.hasNext())
          b.append("| ");
      }
      b.append(")]");
      return b.toString();
    }
  }

  /**
   * We are passing around quite a few parameters. It gets to be a pain, so
   * put everything in an object.
   *
   * @author vijay
   *
   */
  public static class Parameters {
    String classFeatureQuery;
    String conceptGraphName;
    String conceptSetName;
    String corpusName;
    String freqQuery;
    double imputeWeight;

    String labelQuery;
    MeasureType measure;
    double minInfo;
    Double parentConceptEvalThreshold;
    Integer parentConceptTopThreshold;
    String splitName;
    String xLeftover;
    String xMerge;
    Set<String> xVals;

    public Parameters() {

    }

    public Parameters(Properties props) {
      corpusName = props.getProperty("org.apache.ctakes.ytex.corpusName");
      conceptGraphName = props.getProperty("org.apache.ctakes.ytex.conceptGraphName");
      conceptSetName = props.getProperty("org.apache.ctakes.ytex.conceptSetName");
      splitName = props.getProperty("org.apache.ctakes.ytex.splitName");
      labelQuery = props.getProperty("instanceClassQuery");
      classFeatureQuery = props.getProperty("org.apache.ctakes.ytex.conceptInstanceQuery");
      freqQuery = props.getProperty("org.apache.ctakes.ytex.freqQuery");
      minInfo = Double.parseDouble(props.getProperty("min.info", "1e-4"));
      String xValStr = props.getProperty("org.apache.ctakes.ytex.xVals", "0,1");
      xVals = new HashSet<String>();
      xVals.addAll(Arrays.asList(xValStr.split(",")));
      xLeftover = props.getProperty("org.apache.ctakes.ytex.xLeftover", "0");
      xMerge = props.getProperty("org.apache.ctakes.ytex.xMerge", "1");
      this.measure = MeasureType.valueOf(props.getProperty(
          "org.apache.ctakes.ytex.measure", "INFOGAIN"));
      parentConceptEvalThreshold = FileUtil.getDoubleProperty(props,
          "org.apache.ctakes.ytex.parentConceptEvalThreshold", null);
      parentConceptTopThreshold = parentConceptEvalThreshold == null ? FileUtil
          .getIntegerProperty(props,
              "org.apache.ctakes.ytex.parentConceptTopThreshold", 25) : null;
      imputeWeight = FileUtil.getDoubleProperty(props,
          "org.apache.ctakes.ytex.imputeWeight", 1d);
    }

    public String getClassFeatureQuery() {
      return classFeatureQuery;
    }

    public String getConceptGraphName() {
      return conceptGraphName;
    }

    public String getConceptSetName() {
      return conceptSetName;
    }

    public String getCorpusName() {
      return corpusName;
    }

    public String getFreqQuery() {
      return freqQuery;
    }

    public double getImputeWeight() {
      return imputeWeight;
    }

    public String getLabelQuery() {
      return labelQuery;
    }

    public MeasureType getMeasure() {
      return measure;
    }

    public double getMinInfo() {
      return minInfo;
    }

    public Double getParentConceptEvalThreshold() {
      return parentConceptEvalThreshold;
    }

    public Integer getParentConceptTopThreshold() {
      return parentConceptTopThreshold;
    }

    public String getSplitName() {
      return splitName;
    }

    public String getxLeftover() {
      return xLeftover;
    }

    public String getxMerge() {
      return xMerge;
    }

    public Set<String> getxVals() {
      return xVals;
    }
  }

  // /**
  // * iterates through query results and computes infogain
  // *
  // * @author vijay
  // *
  // */
  // public class JointDistroExtractor implements RowCallbackHandler {
  // /**
  // * key - fold
  // * <p/>
  // * value - map of concept id - joint distribution
  // */
  // private Map<String, JointDistribution> jointDistroMap;
  // private Set<String> xVals;
  // private Set<String> yVals;
  // private Map<Integer, String> instanceClassMap;
  //
  // public JointDistroExtractor(
  // Map<String, JointDistribution> jointDistroMap,
  // Set<String> xVals, Set<String> yVals,
  // Map<Integer, String> instanceClassMap) {
  // super();
  // this.xVals = xVals;
  // this.yVals = yVals;
  // this.jointDistroMap = jointDistroMap;
  // this.instanceClassMap = instanceClassMap;
  // }
  //
  // public void processRow(ResultSet rs) throws SQLException {
  // int instanceId = rs.getInt(1);
  // String conceptId = rs.getString(2);
  // String x = rs.getString(3);
  // String y = instanceClassMap.get(instanceId);
  // JointDistribution distro = jointDistroMap.get(conceptId);
  // if (distro == null) {
  // distro = new JointDistribution(xVals, yVals);
  // jointDistroMap.put(conceptId, distro);
  // }
  // distro.addInstance(x, y, instanceId);
  // }
  // }

  // logger shared by all instances of this class
  private static final Log log = LogFactory
      .getLog(ImputedFeatureEvaluatorImpl.class);

  protected static double entropy(double[] classProbs) {
    double entropy = 0;
    double log2 = Math.log(2);
    for (double prob : classProbs) {
      if (prob > 0)
        entropy += prob * Math.log(prob) / log2;
    }
    return entropy * -1;
  }

  /**
   * calculate entropy from a list/array of probabilities
   *
   * @param classProbs
   * @return
   */
  protected static double entropy(Iterable<Double> classProbs) {
    double entropy = 0;
    double log2 = Math.log(2);
    for (double prob : classProbs) {
      if (prob > 0)
        entropy += prob * Math.log(prob) / log2;
    }
    return entropy * -1;
  }

  @SuppressWarnings("static-access")
  public static void main(String args[]) throws ParseException, IOException {
    Options options = new Options();
    options.addOption(OptionBuilder
        .withArgName("prop")
        .hasArg()
        .isRequired()
        .withDescription(
            "property file with queries and other parameters. todo desc")
        .create("prop"));
    try {
      CommandLineParser parser = new GnuParser();
      CommandLine line = parser.parse(options, args);
      if (!KernelContextHolder.getApplicationContext()
          .getBean(ImputedFeatureEvaluator.class)
          .evaluateCorpus(line.getOptionValue("prop"))) {
        printHelp(options);
      }
    } catch (ParseException pe) {
      printHelp(options);
    }
  }

  private static void printHelp(Options options) {
    HelpFormatter formatter = new HelpFormatter();
    formatter
        .printHelp(
            "java "
                + ImputedFeatureEvaluatorImpl.class.getName()
                + " calculate raw, propagated, and imputed infogain for each feature",
            options);
  }

  // used in this class to delete feature evaluations and to look up cross
  // validation folds
  protected ClassifierEvaluationDao classifierEvaluationDao;

  // used in this class to load the concept graph by name
  protected ConceptDao conceptDao;

  // NOTE(review): usage is not visible in this part of the file
  private InfoContentEvaluator infoContentEvaluator;

  // source of the DataSource returned by getDataSource
  protected JdbcTemplate jdbcTemplate;

  // used in this class to load the instances for the label query
  protected KernelUtil kernelUtil;
  // NOTE(review): usage is not visible in this part of the file
  protected NamedParameterJdbcTemplate namedParamJdbcTemplate;

  // NOTE(review): usage is not visible in this part of the file
  protected PlatformTransactionManager transactionManager;

  // presumably what getYtexProperties() returns - TODO confirm
  private Properties ytexProperties = null;

  /**
   * recursively add children of cr to childConcepts
   *
   * @param childConcepts
   * @param cr
   */
  private void addSubtree(Set<String> childConcepts, ConcRel cr) {
    childConcepts.add(cr.getConceptID());
    for (ConcRel crc : cr.getChildren()) {
      addSubtree(childConcepts, crc);
    }
  }

  private JointDistribution calcMergedJointDistribution(
      Map<String, JointDistribution> conceptJointDistroMap,
      Map<String, Integer> conceptDistMap, ConcRel cr,
      Map<String, JointDistribution> rawJointDistroMap,
      Map<String, Set<Long>> yMargin, String xMerge, double minInfo,
      List<String> path) {
    if (conceptJointDistroMap.containsKey(cr.getConceptID())) {
      return conceptJointDistroMap.get(cr.getConceptID());
    } else {
      List<JointDistribution> distroList = new ArrayList<JointDistribution>(
          cr.getChildren().size() + 1);
      int distance = -1;
      // if this concept is in the raw joint distro map, add it to the
      // list of joint distributions to merge
      if (rawJointDistroMap.containsKey(cr.getConceptID())) {
        JointDistribution rawJointDistro = rawJointDistroMap.get(cr
            .getConceptID());
        distroList.add(rawJointDistro);
        distance = 0;
      }
      // get the joint distributions of children
      for (ConcRel crc : cr.getChildren()) {
        List<String> pathChild = new ArrayList<String>(path.size() + 1);
        pathChild.addAll(path);
        pathChild.add(crc.getConceptID());
        // recurse - get joint distribution of children
        JointDistribution jdChild = calcMergedJointDistribution(
            conceptJointDistroMap, conceptDistMap, crc,
            rawJointDistroMap, yMargin, xMerge, minInfo, pathChild);
        if (jdChild != null) {
          distroList.add(jdChild);
          if (distance != 0) {
            // look at children's distance from raw data, add 1
            int distChild = conceptDistMap.get(crc.getConceptID());
            if (distance == -1 || (distChild + 1) < distance) {
              distance = distChild + 1;
            }
          }
        }
      }
      // merge the joint distributions
      JointDistribution mergedDistro;
      if (distroList.size() > 0) {
        if (distroList.size() == 1) {
          // only one joint distro - trivial merge
          mergedDistro = distroList.get(0);
        } else {
          // multiple joint distros - merge them into a new one
          mergedDistro = JointDistribution.merge(distroList, yMargin,
              xMerge);
        }
        // if (log.isDebugEnabled()) {
        // log.debug("path = " + path + ", distroList = " + distroList
        // + ", distro = " + mergedDistro);
        // }
      } else {
        // no joint distros to merge - null
        mergedDistro = null;
      }
      // save this in the map
      conceptJointDistroMap.put(cr.getConceptID(), mergedDistro);
      if (distance > -1)
        conceptDistMap.put(cr.getConceptID(), distance);
      return mergedDistro;
    }
  }

  /**
   *
   */
  private double calculateFoldEntropy(Map<String, Set<Long>> classCountMap) {
    int total = 0;
    List<Double> classProbs = new ArrayList<Double>(classCountMap.size());
    // calculate total number of instances in this fold
    for (Set<Long> instances : classCountMap.values()) {
      total += instances.size();
    }
    // calculate per-class probability in this fold
    for (Set<Long> instances : classCountMap.values()) {
      classProbs.add((double) instances.size() / (double) total);
    }
    return entropy(classProbs);
  }

  /**
   * finalize the joint distribution tables wrt a fold.
   *
   * @param jointDistroMap
   * @param yMargin
   * @param yVals
   * @param xVals
   * @param xLeftover
   */
  private Map<String, JointDistribution> completeJointDistroForFold(
      Map<String, Map<String, Set<Long>>> conceptInstanceMap,
      Map<String, Set<Long>> yMargin, Set<String> xVals,
      Set<String> yVals, String xLeftover) {
    //
    Map<String, JointDistribution> foldJointDistroMap = new HashMap<String, JointDistribution>(
        conceptInstanceMap.size());
    for (Map.Entry<String, Map<String, Set<Long>>> conceptInstance : conceptInstanceMap
        .entrySet()) {
      foldJointDistroMap.put(
          conceptInstance.getKey(),
          new JointDistribution(xVals, yVals, conceptInstance
              .getValue(), yMargin, xLeftover));
    }
    return foldJointDistroMap;
  }

  /**
   * delete the feature evaluations before we insert them
   *
   * @param params
   * @param label
   * @param foldId
   */
  private void deleteFeatureEval(Parameters params, String label, int foldId) {

    for (String type : new String[] { params.getMeasure().getName(),
        params.getMeasure().getName() + SUFFIX_PROP,
        params.getMeasure().getName() + SUFFIX_IMPUTED,
        params.getMeasure().getName() + SUFFIX_IMPUTED_FILTERED })
      this.classifierEvaluationDao.deleteFeatureEvaluation(
          params.getCorpusName(), params.getConceptSetName(), label,
          type, foldId, 0d, params.getConceptGraphName());
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.ctakes.ytex.kernel.CorpusLabelEvaluator#evaluateCorpus(java.lang.String,
   * java.lang.String, java.lang.String, java.lang.String, java.lang.String,
   * java.lang.Double, java.util.Set, java.lang.String, java.lang.String)
   */
  @Override
  public boolean evaluateCorpus(final Parameters params) {
    if (!(params.getCorpusName() != null
        && params.getConceptGraphName() != null
        && params.getLabelQuery() != null && params
        .getClassFeatureQuery() != null))
      return false;
    ConceptGraph cg = conceptDao.getConceptGraph(params
        .getConceptGraphName());
    InstanceData instanceData = kernelUtil.loadInstances(params
        .getLabelQuery());
    for (String label : instanceData.getLabelToInstanceMap().keySet()) {
      evaluateCorpusLabel(params, cg, instanceData, label);
    }
    return true;
  }

  @Override
  public boolean evaluateCorpus(String propFile) throws IOException {
    Properties props = new Properties();
    // put org.apache.ctakes.ytex properties in props
    props.putAll(this.getYtexProperties());
    // override org.apache.ctakes.ytex properties with propfile
    props.putAll(FileUtil.loadProperties(propFile, true));
    return this.evaluateCorpus(new Parameters(props));
  }

  private void evaluateCorpusFold(Parameters params,
      Map<String, Set<Long>> yMargin, ConceptGraph cg,
      InstanceData instanceData, String label,
      Map<String, Map<String, Set<Long>>> conceptInstanceMap, int foldId) {
    if (log.isInfoEnabled())
      log.info("evaluateCorpusFold() label = " + label + ", fold = "
          + foldId);
    deleteFeatureEval(params, label, foldId);

    // get the entropy of Y for this fold
    double yEntropy = this.calculateFoldEntropy(yMargin);
    // get the joint distribution of concepts and instances
    Map<String, JointDistribution> rawJointDistro = this
        .completeJointDistroForFold(conceptInstanceMap, yMargin, params
            .getxVals(),
            instanceData.getLabelToClassMap().get(label), params
                .getxLeftover());
    List<FeatureRank> listRawRanks = new ArrayList<FeatureRank>(
        rawJointDistro.size());
    FeatureEvaluation feRaw = saveFeatureEvaluation(rawJointDistro, params,
        label, foldId, yEntropy, "", listRawRanks);
    // propagate across graph and save
    propagateJointDistribution(rawJointDistro, params, label, foldId, cg,
        yMargin);
    // store children of top concepts
    storeChildConcepts(listRawRanks, params, label, foldId, cg, true);
    storeChildConcepts(listRawRanks, params, label, foldId, cg, false);
  }

  /**
   * evaluate corpus on label
   *
   * @param classFeatureQuery
   * @param minInfo
   * @param xVals
   * @param xLeftover
   * @param xMerge
   * @param eval
   * @param cg
   * @param instanceData
   * @param label
   * @param parentConceptTopThreshold
   * @param parentConceptEvalThreshold
   */
  private void evaluateCorpusLabel(Parameters params, ConceptGraph cg,
      InstanceData instanceData, String label) {
    if (log.isInfoEnabled())
      log.info("evaluateCorpusLabel() label = " + label);
    Map<String, Map<String, Set<Long>>> conceptInstanceMap = loadConceptInstanceMap(
        params.getClassFeatureQuery(), cg, label);
    for (int run : instanceData.getLabelToInstanceMap().get(label).keySet()) {
      for (int fold : instanceData.getLabelToInstanceMap().get(label)
          .get(run).keySet()) {
        int foldId = this.getFoldId(params, label, run, fold);
        // evaluate for the specified fold training set
        // construct map of class - [instance ids]
        Map<String, Set<Long>> yMargin = getFoldYMargin(instanceData,
            label, run, fold);
        evaluateCorpusFold(params, yMargin, cg, instanceData, label,
            conceptInstanceMap, foldId);
      }
    }
  }

  public ClassifierEvaluationDao getClassifierEvaluationDao() {
    return classifierEvaluationDao;
  }

  public ConceptDao getConceptDao() {
    return conceptDao;
  }

  public DataSource getDataSource(DataSource ds) {
    return this.jdbcTemplate.getDataSource();
  }

  private int getFoldId(Parameters params, String label, int run, int fold) {
    // figure out fold id
    int foldId = 0;
    if (run > 0 && fold > 0) {
      CrossValidationFold cvFold = this.classifierEvaluationDao
          .getCrossValidationFold(params.getCorpusName(),
              params.getSplitName(), label, run, fold);
      if (cvFold != null) {
        foldId = cvFold.getCrossValidationFoldId();
      } else {
        log.warn("could not find cv fold, name="
            + params.getCorpusName() + ", run=" + run + ", fold="
            + fold);
      }
    }
    return foldId;
  }

  private Map<String, Set<Long>> getFoldYMargin(InstanceData instanceData,
      String label, int run, int fold) {
    Map<Long, String> instanceClassMap = instanceData
        .getLabelToInstanceMap().get(label).get(run).get(fold)
        .get(true);
    Map<String, Set<Long>> yMargin = new HashMap<String, Set<Long>>();
    for (Map.Entry<Long, String> instanceClass : instanceClassMap
        .entrySet()) {
      Set<Long> instanceIds = yMargin.get(instanceClass.getValue());
      if (instanceIds == null) {
        instanceIds = new HashSet<Long>();
        yMargin.put(instanceClass.getValue(), instanceIds);
      }
      instanceIds.add(instanceClass.getKey());
    }
    return yMargin;
  }

  public InfoContentEvaluator getInfoContentEvaluator() {
    return infoContentEvaluator;
  }

  public KernelUtil getKernelUtil() {
    return kernelUtil;
  }

  public Properties getYtexProperties() {
    return ytexProperties;
  }

  private FeatureEvaluation initFeatureEval(Parameters params, String label,
      int foldId, String type) {
    FeatureEvaluation feval = new FeatureEvaluation();
    feval.setCorpusName(params.getCorpusName());
    feval.setLabel(label);
    feval.setCrossValidationFoldId(foldId);
    feval.setParam2(params.getConceptGraphName());
    feval.setEvaluationType(type);
    feval.setFeatureSetName(params.getConceptSetName());
    return feval;
  }

  /**
   * load the map of concept - instances
   *
   * @param classFeatureQuery
   * @param cg
   * @param label
   * @return
   */
  private Map<String, Map<String, Set<Long>>> loadConceptInstanceMap(
      String classFeatureQuery, ConceptGraph cg, String label) {
    Map<String, Map<String, Set<Long>>> conceptInstanceMap = new HashMap<String, Map<String, Set<Long>>>();
    Map<String, Object> args = new HashMap<String, Object>(1);
    if (label != null && label.length() > 0) {
      args.put("label", label);
    }
    ConceptInstanceMapExtractor ex = new ConceptInstanceMapExtractor(
        conceptInstanceMap, cg);
    this.namedParamJdbcTemplate.query(classFeatureQuery, args, ex);
    return conceptInstanceMap;
  }

  /**
   * 'complete' the joint distribution tables wrt a fold (yMargin). propagate
   * the joint distribution of all concepts recursively.
   *
   * @param rawJointDistroMap
   * @param labelEval
   * @param cg
   * @param yMargin
   * @param xMerge
   * @param minInfo
   */
  private FeatureEvaluation propagateJointDistribution(
      Map<String, JointDistribution> rawJointDistroMap,
      Parameters params, String label, int foldId, ConceptGraph cg,
      Map<String, Set<Long>> yMargin) {
    // get the entropy of Y for this fold
    double yEntropy = this.calculateFoldEntropy(yMargin);
    // allocate a map to hold the results of the propagation across the
    // concept graph
    Map<String, JointDistribution> conceptJointDistroMap = new HashMap<String, JointDistribution>(
        cg.getConceptMap().size());
    Map<String, Integer> conceptDistMap = new HashMap<String, Integer>();
    // recurse
    calcMergedJointDistribution(conceptJointDistroMap, conceptDistMap, cg
        .getConceptMap().get(cg.getRoot()), rawJointDistroMap, yMargin,
        params.getxMerge(), params.getMinInfo(),
        Arrays.asList(new String[] { cg.getRoot() }));
    List<FeatureRank> listPropRanks = new ArrayList<FeatureRank>(
        conceptJointDistroMap.size());
    return this.saveFeatureEvaluation(conceptJointDistroMap, params, label,
        foldId, yEntropy, SUFFIX_PROP, listPropRanks);
  }

  private List<FeatureRank> rank(MeasureType measureType,
      FeatureEvaluation fe,
      Map<String, JointDistribution> rawJointDistro, double yEntropy,
      List<FeatureRank> featureRankList) {
    for (Map.Entry<String, JointDistribution> conceptJointDistro : rawJointDistro
        .entrySet()) {
      JointDistribution d = conceptJointDistro.getValue();
      if (d != null) {
        double evaluation;
        if (MeasureType.MUTUALINFO.equals(measureType)) {
          evaluation = d.getMutualInformation(yEntropy);
        } else {
          evaluation = d.getInfoGain();
        }
        if (evaluation > 1e-3) {
          FeatureRank r = new FeatureRank(fe,
              conceptJointDistro.getKey(), evaluation);
          featureRankList.add(r);
        }
      }
    }
    return FeatureRank.sortFeatureRankList(featureRankList,
        new FeatureRank.FeatureRankDesc());
  }

  private FeatureEvaluation saveFeatureEvaluation(
      Map<String, JointDistribution> rawJointDistro, Parameters params,
      String label, int foldId, double yEntropy, String suffix,
      List<FeatureRank> listRawRanks) {
    FeatureEvaluation fe = initFeatureEval(params, label, foldId, params
        .getMeasure().getName() + suffix);
    this.classifierEvaluationDao.saveFeatureEvaluation(
        fe,
        rank(params.getMeasure(), fe, rawJointDistro, yEntropy,
            listRawRanks));
    return fe;
  }

  /**
   * Set the DAO this evaluator uses to look up cross validation folds and to
   * load and save feature evaluations.
   *
   * @param classifierEvaluationDao
   *            DAO to use
   */
  public void setClassifierEvaluationDao(
      ClassifierEvaluationDao classifierEvaluationDao) {
    this.classifierEvaluationDao = classifierEvaluationDao;
  }

  /**
   * Set the concept DAO of this evaluator.
   *
   * @param conceptDao
   *            DAO to use
   */
  public void setConceptDao(ConceptDao conceptDao) {
    this.conceptDao = conceptDao;
  }

  /**
   * Set the data source. A new JdbcTemplate and a new
   * NamedParameterJdbcTemplate are created on top of it, replacing any
   * templates created by an earlier call.
   *
   * @param ds
   *            data source both templates are created from
   */
  public void setDataSource(DataSource ds) {
    this.jdbcTemplate = new JdbcTemplate(ds);
    this.namedParamJdbcTemplate = new NamedParameterJdbcTemplate(ds);
  }

  // private CorpusLabelEvaluation initCorpusLabelEval(CorpusEvaluation eval,
  // String label, String splitName, int run, int fold) {
  // Integer foldId = getFoldId(eval, label, splitName, run, fold);
  // // see if the labelEval is already there
  // CorpusLabelEvaluation labelEval = corpusDao.getCorpusLabelEvaluation(
  // eval.getCorpusName(), eval.getConceptGraphName(),
  // eval.getConceptSetName(), label, foldId);
  // if (labelEval == null) {
  // // not there - add it
  // labelEval = new CorpusLabelEvaluation();
  // labelEval.setCorpus(eval);
  // labelEval.setFoldId(foldId);
  // labelEval.setLabel(label);
  // corpusDao.addCorpusLabelEval(labelEval);
  // }
  // return labelEval;
  // }

  /**
   * Set the info content evaluator; storeChildConcepts uses it to get the
   * concept frequencies when it is not imputing to all concepts.
   *
   * @param infoContentEvaluator
   *            evaluator to use
   */
  public void setInfoContentEvaluator(
      InfoContentEvaluator infoContentEvaluator) {
    this.infoContentEvaluator = infoContentEvaluator;
  }

  // /**
  // * create the corpusEvaluation if it doesn't exist
  // *
  // * @param corpusName
  // * @param conceptGraphName
  // * @param conceptSetName
  // * @return
  // */
  // private CorpusEvaluation initEval(String corpusName,
  // String conceptGraphName, String conceptSetName) {
  // CorpusEvaluation eval = this.corpusDao.getCorpus(corpusName,
  // conceptGraphName, conceptSetName);
  // if (eval == null) {
  // eval = new CorpusEvaluation();
  // eval.setConceptGraphName(conceptGraphName);
  // eval.setConceptSetName(conceptSetName);
  // eval.setCorpusName(corpusName);
  // this.corpusDao.addCorpus(eval);
  // }
  // return eval;
  // }

  /**
   * Set the kernel utility of this evaluator.
   *
   * @param kernelUtil
   *            utility to use
   */
  public void setKernelUtil(KernelUtil kernelUtil) {
    this.kernelUtil = kernelUtil;
  }

  //
  // private void saveLabelStatistic(String conceptID,
  // JointDistribution distroMerged, JointDistribution distroRaw,
  // CorpusLabelEvaluation labelEval, double yEntropy, double minInfo,
  // int distance) {
  // double miMerged = distroMerged.getMutualInformation(yEntropy);
  // double miRaw = distroRaw != null ? distroRaw
  // .getMutualInformation(yEntropy) : 0;
  // if (miMerged > minInfo || miRaw > minInfo) {
  // ConceptLabelStatistic stat = new ConceptLabelStatistic();
  // stat.setCorpusLabel(labelEval);
  // stat.setMutualInfo(miMerged);
  // if (distroRaw != null)
  // stat.setMutualInfoRaw(miRaw);
  // stat.setConceptId(conceptID);
  // stat.setDistance(distance);
  // this.corpusDao.addLabelStatistic(stat);
  // }
  // }

  /**
   * Set the ytex properties of this evaluator.
   *
   * @param ytexProperties
   *            properties to use
   */
  public void setYtexProperties(Properties ytexProperties) {
    this.ytexProperties = ytexProperties;
  }

  /**
   * save the children of the 'top' parent concepts.
   *
   * @param labelEval
   * @param parentConceptTopThreshold
   * @param parentConceptEvalThreshold
   * @param cg
   * @param bAll
   *            impute to all concepts/concepts actually in corpus. if we are
   *            imputing to all concepts, filter by infocontent (this includes
   *            hypernyms of concepts in the corpus). else only impute to
   *            conrete concepts in the corpus
   */
  public void storeChildConcepts(List<FeatureRank> listRawRanks,
      Parameters params, String label, int foldId, ConceptGraph cg,
      boolean bAll) {
    // only include concepts that actually occur in the corpus
    Map<String, Double> conceptICMap = bAll ? classifierEvaluationDao
        .getInfoContent(params.getCorpusName(),
            params.getConceptGraphName(),
            params.getConceptSetName()) : this.infoContentEvaluator
        .getFrequencies(params.getFreqQuery());
    // get the raw feature evaluations. The imputed feature evaluation is a
    // mixture of the parent feature eval and the raw feature eval.
    Map<String, Double> conceptRawEvalMap = new HashMap<String, Double>(
        listRawRanks.size());
    for (FeatureRank r : listRawRanks) {
      conceptRawEvalMap.put(r.getFeatureName(), r.getEvaluation());
    }
    // this map will get filled with the links between parent and child
    // concepts for imputation
    Map<FeatureRank, Set<FeatureRank>> childParentMap = bAll ? null
        : new HashMap<FeatureRank, Set<FeatureRank>>();
    // .getFeatureRankEvaluations(params.getCorpusName(),
    // params.getConceptSetName(), null,
    // InfoContentEvaluator.INFOCONTENT, 0,
    // params.getConceptGraphName());
    // get the top parent concepts - use either top N, or those with a
    // cutoff greater than the specified threshold
    // List<ConceptLabelStatistic> listConceptStat =
    // parentConceptTopThreshold != null ? this.corpusDao
    // .getTopCorpusLabelStat(labelEval, parentConceptTopThreshold)
    // : this.corpusDao.getThresholdCorpusLabelStat(labelEval,
    // parentConceptMutualInfoThreshold);
    String propagatedType = params.getMeasure().getName() + SUFFIX_PROP;
    List<FeatureRank> listConceptStat = params
        .getParentConceptTopThreshold() != null ? this.classifierEvaluationDao
        .getTopFeatures(params.getCorpusName(),
            params.getConceptSetName(), label, propagatedType,
            foldId, 0, params.getConceptGraphName(),
            params.getParentConceptTopThreshold())
        : this.classifierEvaluationDao.getThresholdFeatures(
            params.getCorpusName(), params.getConceptSetName(),
            label, propagatedType, foldId, 0,
            params.getConceptGraphName(),
            params.getParentConceptEvalThreshold());
    FeatureEvaluation fe = this.initFeatureEval(params, label, foldId,
        params.getMeasure().getName()
            + (bAll ? SUFFIX_IMPUTED : SUFFIX_IMPUTED_FILTERED));
    // map of concept id to children and the 'best' statistic
    Map<String, FeatureRank> mapChildConcept = new HashMap<String, FeatureRank>();
    // get all the children of the parent concepts
    for (FeatureRank parentConcept : listConceptStat) {
      updateChildren(parentConcept, mapChildConcept, fe, cg,
          conceptICMap, conceptRawEvalMap, childParentMap,
          params.getImputeWeight(), params.getMinInfo());
    }
    // save the imputed feature ranks
    List<FeatureRank> features = new ArrayList<FeatureRank>(
        mapChildConcept.values());
    FeatureRank.sortFeatureRankList(features,
        new FeatureRank.FeatureRankDesc());
    this.classifierEvaluationDao.saveFeatureEvaluation(fe, features);
    if (!bAll) {
      // save the parent-child links
      for (Map.Entry<FeatureRank, Set<FeatureRank>> childParentEntry : childParentMap
          .entrySet()) {
        FeatureRank child = childParentEntry.getKey();
        for (FeatureRank parent : childParentEntry.getValue()) {
          FeatureParentChild parchd = new FeatureParentChild();
          parchd.setFeatureRankParent(parent);
          parchd.setFeatureRankChild(child);
          this.classifierEvaluationDao.saveFeatureParentChild(parchd);
        }
      }
    }
  }

  /**
   * add the children of parentConcept to mapChildConcept. Assign the child
   * the best mutual information value of the parent.
   *
   * @param parentConcept
   * @param mapChildConcept
   * @param labelEval
   * @param cg
   * @param parentChildMap
   * @param conceptRawEvalMap
   */
  private void updateChildren(FeatureRank parentConcept,
      Map<String, FeatureRank> mapChildConcept, FeatureEvaluation fe,
      ConceptGraph cg, Map<String, Double> conceptICMap,
      Map<String, Double> conceptRawEvalMap,
      Map<FeatureRank, Set<FeatureRank>> childParentMap,
      double imputeWeight, double minInfo) {
    ConcRel cr = cg.getConceptMap().get(parentConcept.getFeatureName());
    Set<String> childConcepts = new HashSet<String>();
    addSubtree(childConcepts, cr);
    for (String childConceptId : childConcepts) {
      // only add the child to the map if it exists in the corpus
      if (conceptICMap.containsKey(childConceptId)) {
        FeatureRank chd = mapChildConcept.get(childConceptId);
        // create the child if it does not already exist
        if (chd == null) {
          chd = new FeatureRank(fe, childConceptId, 0d);
          mapChildConcept.put(childConceptId, chd);
        }
        // give the child the mutual info of the parent with the highest
        // score
        double rawEvaluation = conceptRawEvalMap
            .containsKey(childConceptId) ? conceptRawEvalMap
            .get(childConceptId) : minInfo;
        double imputedEvaluation = (imputeWeight * parentConcept
            .getEvaluation())
            + ((1 - imputeWeight) * rawEvaluation);
        if (chd.getEvaluation() < imputedEvaluation) {
          chd.setEvaluation(imputedEvaluation);
        }
        // add the relationship to the parentChildMap
        // do this only if the childParentMap is not null
        if (childParentMap != null) {
          Set<FeatureRank> parents = childParentMap.get(chd);
          if (parents == null) {
            parents = new HashSet<FeatureRank>(10);
            childParentMap.put(chd, parents);
          }
          parents.add(parentConcept);
        }
      }
    }
  }
}
TOP

Related Classes of org.apache.ctakes.ytex.kernel.ImputedFeatureEvaluatorImpl$Parameters

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.