Package org.apache.ctakes.ytex.kernel

Source Code of org.apache.ctakes.ytex.kernel.KernelUtilImpl

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied.  See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.ytex.kernel;

import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.InvalidPropertiesFormatException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;

import javax.sql.DataSource;

import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao;
import org.apache.ctakes.ytex.kernel.dao.KernelEvaluationDao;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.KernelEvaluation;
import org.apache.ctakes.ytex.kernel.model.KernelEvaluationInstance;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;


public class KernelUtilImpl implements KernelUtil {
  private static final Log log = LogFactory.getLog(KernelUtilImpl.class);
  private ClassifierEvaluationDao classifierEvaluationDao;

  private JdbcTemplate jdbcTemplate = null;

  private KernelEvaluationDao kernelEvaluationDao = null;
  private PlatformTransactionManager transactionManager;
  private FoldGenerator foldGenerator = null;

  public FoldGenerator getFoldGenerator() {
    return foldGenerator;
  }

  public void setFoldGenerator(FoldGenerator foldGenerator) {
    this.foldGenerator = foldGenerator;
  }

  private Map<Long, Integer> createInstanceIdToIndexMap(
      SortedSet<Long> instanceIDs) {
    Map<Long, Integer> instanceIdToIndexMap = new HashMap<Long, Integer>(
        instanceIDs.size());
    int i = 0;
    for (Long instanceId : instanceIDs) {
      instanceIdToIndexMap.put(instanceId, i);
      i++;
    }
    return instanceIdToIndexMap;
  }

  @Override
  public void fillGramMatrix(final KernelEvaluation kernelEvaluation,
      final SortedSet<Long> trainInstanceLabelMap,
      final double[][] trainGramMatrix) {
    // final Set<String> kernelEvaluationNames = new HashSet<String>(1);
    // kernelEvaluationNames.add(name);
    // prepare map of instance id to gram matrix index
    final Map<Long, Integer> trainInstanceToIndexMap = createInstanceIdToIndexMap(trainInstanceLabelMap);

    // iterate through the training instances
    for (Map.Entry<Long, Integer> instanceIdIndex : trainInstanceToIndexMap
        .entrySet()) {
      // index of this instance
      final int indexThis = instanceIdIndex.getValue();
      // id of this instance
      final long instanceId = instanceIdIndex.getKey();
      // get all kernel evaluations for this instance in a new transaction
      // don't want too many objects in hibernate session
      TransactionTemplate t = new TransactionTemplate(
          this.transactionManager);
      t.setPropagationBehavior(TransactionTemplate.PROPAGATION_REQUIRES_NEW);
      t.execute(new TransactionCallback<Object>() {
        @Override
        public Object doInTransaction(TransactionStatus arg0) {
          List<KernelEvaluationInstance> kevals = getKernelEvaluationDao()
              .getAllKernelEvaluationsForInstance(
                  kernelEvaluation, instanceId);
          for (KernelEvaluationInstance keval : kevals) {
            // determine the index of the instance
            Integer indexOtherTrain = null;
            long instanceIdOther = instanceId != keval
                .getInstanceId1() ? keval.getInstanceId1()
                : keval.getInstanceId2();
            // look in training set for the instance id
            indexOtherTrain = trainInstanceToIndexMap
                .get(instanceIdOther);
            if (indexOtherTrain != null) {
              trainGramMatrix[indexThis][indexOtherTrain] = keval
                  .getSimilarity();
              trainGramMatrix[indexOtherTrain][indexThis] = keval
                  .getSimilarity();
            }
          }
          return null;
        }
      });
    }
    // put 1's in the diagonal of the training gram matrix
    for (int i = 0; i < trainGramMatrix.length; i++) {
      if (trainGramMatrix[i][i] == 0)
        trainGramMatrix[i][i] = 1;
    }
  }

  public ClassifierEvaluationDao getClassifierEvaluationDao() {
    return classifierEvaluationDao;
  }

  public DataSource getDataSource() {
    return jdbcTemplate.getDataSource();
  }

  public KernelEvaluationDao getKernelEvaluationDao() {
    return kernelEvaluationDao;
  }

  public PlatformTransactionManager getTransactionManager() {
    return transactionManager;
  }

  @Override
  public double[][] loadGramMatrix(SortedSet<Long> instanceIds, String name,
      String splitName, String experiment, String label, int run,
      int fold, double param1, String param2) {
    int foldId = 0;
    double[][] gramMatrix = null;
    if (run != 0 && fold != 0) {
      CrossValidationFold f = this.classifierEvaluationDao
          .getCrossValidationFold(name, splitName, label, run, fold);
      if (f != null)
        foldId = f.getCrossValidationFoldId();
    }
    KernelEvaluation kernelEval = this.kernelEvaluationDao.getKernelEval(
        name, experiment, label, foldId, param1, param2);
    if (kernelEval == null) {
      log.warn("could not find kernelEvaluation.  name=" + name
          + ", experiment=" + experiment + ", label=" + label
          + ", fold=" + fold + ", run=" + run);
    } else {
      gramMatrix = new double[instanceIds.size()][instanceIds.size()];
      fillGramMatrix(kernelEval, instanceIds, gramMatrix);
    }
    return gramMatrix;
  }

  /**
   * this can be very large - avoid loading the entire jdbc ResultSet into
   * memory
   */
  @Override
  public InstanceData loadInstances(String strQuery) {
    final InstanceData instanceLabel = new InstanceData();
    PreparedStatement s = null;
    Connection conn = null;
    ResultSet rs = null;
    try {
      // jdbcTemplate.query(strQuery, new RowCallbackHandler() {
      RowCallbackHandler ch = new RowCallbackHandler() {

        @Override
        public void processRow(ResultSet rs) throws SQLException {
          String label = "";
          int run = 0;
          int fold = 0;
          boolean train = true;
          long instanceId = rs.getLong(1);
          String className = rs.getString(2);
          if (rs.getMetaData().getColumnCount() >= 3)
            train = rs.getBoolean(3);
          if (rs.getMetaData().getColumnCount() >= 4) {
            label = rs.getString(4);
            if (label == null)
              label = "";
          }
          if (rs.getMetaData().getColumnCount() >= 5)
            fold = rs.getInt(5);
          if (rs.getMetaData().getColumnCount() >= 6)
            run = rs.getInt(6);
          // get runs for label
          SortedMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>> runToInstanceMap = instanceLabel
              .getLabelToInstanceMap().get(label);
          if (runToInstanceMap == null) {
            runToInstanceMap = new TreeMap<Integer, SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>>();
            instanceLabel.getLabelToInstanceMap().put(label,
                runToInstanceMap);
          }
          // get folds for run
          SortedMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>> foldToInstanceMap = runToInstanceMap
              .get(run);
          if (foldToInstanceMap == null) {
            foldToInstanceMap = new TreeMap<Integer, SortedMap<Boolean, SortedMap<Long, String>>>();
            runToInstanceMap.put(run, foldToInstanceMap);
          }
          // get train/test set for fold
          SortedMap<Boolean, SortedMap<Long, String>> ttToClassMap = foldToInstanceMap
              .get(fold);
          if (ttToClassMap == null) {
            ttToClassMap = new TreeMap<Boolean, SortedMap<Long, String>>();
            foldToInstanceMap.put(fold, ttToClassMap);
          }
          // get instances for train/test set
          SortedMap<Long, String> instanceToClassMap = ttToClassMap
              .get(train);
          if (instanceToClassMap == null) {
            instanceToClassMap = new TreeMap<Long, String>();
            ttToClassMap.put(train, instanceToClassMap);
          }
          // set the instance class
          instanceToClassMap.put(instanceId, className);
          // add the class to the labelToClassMap
          SortedSet<String> labelClasses = instanceLabel
              .getLabelToClassMap().get(label);
          if (labelClasses == null) {
            labelClasses = new TreeSet<String>();
            instanceLabel.getLabelToClassMap().put(label,
                labelClasses);
          }
          if (!labelClasses.contains(className))
            labelClasses.add(className);
        }
      };
      conn = this.jdbcTemplate.getDataSource().getConnection();
      s = conn.prepareStatement(strQuery,
          java.sql.ResultSet.TYPE_FORWARD_ONLY,
          java.sql.ResultSet.CONCUR_READ_ONLY);
      if ("MySQL".equals(conn.getMetaData().getDatabaseProductName())) {
        s.setFetchSize(Integer.MIN_VALUE);
      } else if (s.getClass().getName()
          .equals("com.microsoft.sqlserver.jdbc.SQLServerStatement")) {
        try {
          BeanUtils.setProperty(s, "responseBuffering", "adaptive");
        } catch (IllegalAccessException e) {
          log.warn("error setting responseBuffering", e);
        } catch (InvocationTargetException e) {
          log.warn("error setting responseBuffering", e);
        }
      }
      rs = s.executeQuery();
      while (rs.next()) {
        ch.processRow(rs);
      }
    } catch (SQLException j) {
      log.error("loadInstances failed", j);
      throw new RuntimeException(j);
    } finally {
      if (rs != null) {
        try {
          rs.close();
        } catch (SQLException e) {
        }
      }
      if (s != null) {
        try {
          s.close();
        } catch (SQLException e) {
        }
      }
      if (conn != null) {
        try {
          conn.close();
        } catch (SQLException e) {
        }
      }
    }
    return instanceLabel;
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.ctakes.ytex.kernel.DataExporter#loadProperties(java.lang.String,
   * java.util.Properties)
   */
  @Override
  public void loadProperties(String propertyFile, Properties props)
      throws FileNotFoundException, IOException,
      InvalidPropertiesFormatException {
    InputStream in = null;
    try {
      in = new FileInputStream(propertyFile);
      if (propertyFile.endsWith(".xml"))
        props.loadFromXML(in);
      else
        props.load(in);
    } finally {
      if (in != null) {
        in.close();
      }
    }
  }

  public void setClassifierEvaluationDao(
      ClassifierEvaluationDao classifierEvaluationDao) {
    this.classifierEvaluationDao = classifierEvaluationDao;
  }

  public void setDataSource(DataSource dataSource) {
    this.jdbcTemplate = new JdbcTemplate(dataSource);
  }

  public void setKernelEvaluationDao(KernelEvaluationDao kernelEvaluationDao) {
    this.kernelEvaluationDao = kernelEvaluationDao;
  }

  public void setTransactionManager(
      PlatformTransactionManager transactionManager) {
    this.transactionManager = transactionManager;
  }

  @Override
  public void generateFolds(InstanceData instanceLabel, Properties props) {
    int folds = Integer.parseInt(props.getProperty("folds"));
    int runs = Integer.parseInt(props.getProperty("runs", "1"));
    int minPerClass = Integer.parseInt(props
        .getProperty("minPerClass", "0"));
    Integer randomNumberSeed = props.containsKey("rand") ? Integer
        .parseInt(props.getProperty("rand")) : null;
    instanceLabel.setLabelToInstanceMap(foldGenerator.generateRuns(
        instanceLabel.getLabelToInstanceMap(), folds, minPerClass,
        randomNumberSeed, runs));
  }

  /**
   * assign numeric indices to string class names
   *
   * @param labelToClasMap
   * @param labelToClassIndexMap
   */
  @Override
  public void fillLabelToClassToIndexMap(
      Map<String, SortedSet<String>> labelToClasMap,
      Map<String, BiMap<String, Integer>> labelToClassIndexMap) {
    for (Map.Entry<String, SortedSet<String>> labelToClass : labelToClasMap
        .entrySet()) {
      BiMap<String, Integer> classToIndexMap = HashBiMap.create();
      labelToClassIndexMap.put(labelToClass.getKey(), classToIndexMap);
      int nIndex = 1;
      for (String className : labelToClass.getValue()) {
        Integer classNumber = null;
        try {
          classNumber = Integer.parseInt(className);
        } catch (NumberFormatException fe) {
        }
        if (classNumber == null) {
          classToIndexMap.put(className, nIndex++);
        } else {
          classToIndexMap.put(className, classNumber);
        }
      }
    }
  }

  /**
   * export the class id to class name map.
   *
   * @param classIdMap
   * @param label
   * @param run
   * @param fold
   * @throws IOException
   */
  public void exportClassIds(String outdir, Map<String, Integer> classIdMap,
      String label) throws IOException {
    // construct file name
    String filename = FileUtil.getScopedFileName(outdir, label, null, null,
        "class.properties");
    Properties props = new Properties();
    for (Map.Entry<String, Integer> entry : classIdMap.entrySet()) {
      props.put(entry.getValue().toString(), entry.getKey());
    }
    BufferedWriter w = null;
    try {
      w = new BufferedWriter(new FileWriter(filename));
      props.store(w, "class id to class name map");
    } finally {
      if (w != null) {
        w.close();
      }
    }
  }
}
TOP

Related Classes of org.apache.ctakes.ytex.kernel.KernelUtilImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.