/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ctakes.ytex.kernel.dao;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ctakes.ytex.dao.DBUtil;
import org.apache.ctakes.ytex.kernel.InfoContentEvaluator;
import org.apache.ctakes.ytex.kernel.IntrinsicInfoContentEvaluator;
import org.apache.ctakes.ytex.kernel.metric.ConceptInfo;
import org.apache.ctakes.ytex.kernel.model.ClassifierEvaluation;
import org.apache.ctakes.ytex.kernel.model.ClassifierEvaluationIRStat;
import org.apache.ctakes.ytex.kernel.model.ClassifierInstanceEvaluation;
import org.apache.ctakes.ytex.kernel.model.CrossValidationFold;
import org.apache.ctakes.ytex.kernel.model.FeatureEvaluation;
import org.apache.ctakes.ytex.kernel.model.FeatureParentChild;
import org.apache.ctakes.ytex.kernel.model.FeatureRank;
import org.hibernate.Query;
import org.hibernate.SessionFactory;
import org.hibernate.type.Type;
public class ClassifierEvaluationDaoImpl implements ClassifierEvaluationDao {
private static final Log log = LogFactory
.getLog(ClassifierEvaluationDaoImpl.class);
private SessionFactory sessionFactory;
public SessionFactory getSessionFactory() {
return sessionFactory;
}
public void setSessionFactory(SessionFactory sessionFactory) {
this.sessionFactory = sessionFactory;
}
@SuppressWarnings("unchecked")
@Override
public void deleteCrossValidationFoldByName(String corpusName,
String splitName) {
Query q = this.getSessionFactory().getCurrentSession()
.getNamedQuery("getCrossValidationFoldByName");
q.setString("corpusName", corpusName);
q.setString("splitName", nullToEmptyString(splitName));
List<CrossValidationFold> folds = q.list();
for (CrossValidationFold fold : folds)
this.getSessionFactory().getCurrentSession().delete(fold);
}
@Override
public CrossValidationFold getCrossValidationFold(String corpusName,
String splitName, String label, int run, int fold) {
Query q = this.getSessionFactory().getCurrentSession()
.getNamedQuery("getCrossValidationFold");
q.setString("corpusName", corpusName);
q.setString("splitName", nullToEmptyString(splitName));
q.setString("label", nullToEmptyString(label));
q.setInteger("run", run);
q.setInteger("fold", fold);
return (CrossValidationFold) q.uniqueResult();
}
/*
* (non-Javadoc)
*
* @see
* org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDao#saveClassifierEvaluation(org.apache.ctakes.ytex
* .kernel.model.ClassifierEvaluation)
*/
public void saveClassifierEvaluation(ClassifierEvaluation eval,
Map<Integer, String> irClassMap, boolean saveInstanceEval) {
saveClassifierEvaluation(eval, irClassMap, saveInstanceEval, true, null);
}
public void saveClassifierEvaluation(ClassifierEvaluation eval,
Map<Integer, String> irClassMap, boolean saveInstanceEval,
boolean saveIRStats, Integer excludeTargetClassId) {
this.getSessionFactory().getCurrentSession().save(eval);
if (saveIRStats)
this.saveIRStats(eval, irClassMap, excludeTargetClassId);
if (saveInstanceEval) {
for (ClassifierInstanceEvaluation instanceEval : eval
.getClassifierInstanceEvaluations().values()) {
this.getSessionFactory().getCurrentSession().save(instanceEval);
}
}
}
void saveIRStats(ClassifierEvaluation eval,
Map<Integer, String> irClassMap, Integer excludeTargetClassId) {
Set<Integer> classIds = this.getClassIds(eval, excludeTargetClassId);
// setup stats
for (Integer irClassId : classIds) {
String irClass = null;
if (irClassMap != null)
irClass = irClassMap.get(irClassId);
if (irClass == null)
irClass = Integer.toString(irClassId);
ClassifierEvaluationIRStat irStat = calcIRStats(irClass, irClassId,
eval, excludeTargetClassId);
this.getSessionFactory().getCurrentSession().save(irStat);
}
}
/**
*
* @param irClassId
* the target class id with respect to ir statistics will be
* calculated
* @param eval
* the object to update
* @param excludeTargetClassId
* class id to be excluded from computation of ir stats.
* @return
*/
private ClassifierEvaluationIRStat calcIRStats(String irClass,
Integer irClassId, ClassifierEvaluation eval,
Integer excludeTargetClassId) {
int tp = 0;
int tn = 0;
int fp = 0;
int fn = 0;
for (ClassifierInstanceEvaluation instanceEval : eval
.getClassifierInstanceEvaluations().values()) {
if (instanceEval.getTargetClassId() != null
&& (excludeTargetClassId == null || instanceEval
.getTargetClassId() != excludeTargetClassId
.intValue())) {
if (instanceEval.getTargetClassId() == irClassId) {
if (instanceEval.getPredictedClassId() == instanceEval
.getTargetClassId()) {
tp++;
} else {
fn++;
}
} else {
if (instanceEval.getPredictedClassId() == irClassId) {
fp++;
} else {
tn++;
}
}
}
}
return new ClassifierEvaluationIRStat(eval, null, irClass, irClassId,
tp, tn, fp, fn);
}
private Set<Integer> getClassIds(ClassifierEvaluation eval,
Integer excludeTargetClassId) {
Set<Integer> classIds = new HashSet<Integer>();
for (ClassifierInstanceEvaluation instanceEval : eval
.getClassifierInstanceEvaluations().values()) {
classIds.add(instanceEval.getPredictedClassId());
if (instanceEval.getTargetClassId() != null
&& (excludeTargetClassId == null || instanceEval
.getTargetClassId() != excludeTargetClassId
.intValue()))
classIds.add(instanceEval.getTargetClassId());
}
return classIds;
}
@Override
public void saveFold(CrossValidationFold fold) {
this.getSessionFactory().getCurrentSession().save(fold);
}
// @Override
// public void saveInfogain(List<FeatureInfogain> foldInfogainList) {
// for(FeatureInfogain ig : foldInfogainList) {
// this.getSessionFactory().getCurrentSession().save(ig);
// }
// }
@Override
public void saveFeatureEvaluation(FeatureEvaluation featureEvaluation,
List<FeatureRank> features) {
this.getSessionFactory().getCurrentSession().save(featureEvaluation);
for (FeatureRank r : features)
this.getSessionFactory().getCurrentSession().save(r);
}
@SuppressWarnings("unchecked")
@Override
public void deleteFeatureEvaluationByNameAndType(String corpusName,
String featureSetName, String type) {
Query q = this.getSessionFactory().getCurrentSession()
.getNamedQuery("getFeatureEvaluationByNameAndType");
q.setString("corpusName", corpusName);
q.setString("featureSetName", nullToEmptyString(featureSetName));
q.setString("type", type);
for (FeatureEvaluation fe : (List<FeatureEvaluation>) q.list())
this.getSessionFactory().getCurrentSession().delete(fe);
}
@SuppressWarnings("unchecked")
@Override
public List<FeatureRank> getTopFeatures(String corpusName,
String featureSetName, String label, String evaluationType,
Integer foldId, double param1, String param2,
Integer parentConceptTopThreshold) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2, "getTopFeatures");
q.setMaxResults(parentConceptTopThreshold);
return q.list();
}
@Override
public Double getMaxFeatureEvaluation(String corpusName,
String featureSetName, String label, String evaluationType,
Integer foldId, double param1, String param2) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2,
"getMaxFeatureEvaluation");
return (Double) q.uniqueResult();
}
private Query prepareUniqueFeatureEvalQuery(String corpusName,
String featureSetName, String label, String evaluationType,
Integer foldId, Double param1, String param2, String queryName) {
Query q = this.sessionFactory.getCurrentSession().getNamedQuery(
queryName);
q.setString("corpusName", nullToEmptyString(corpusName));
q.setString("featureSetName", nullToEmptyString(featureSetName));
q.setString("label", nullToEmptyString(label));
q.setString("evaluationType", evaluationType);
q.setDouble("param1", param1 == null ? 0 : param1);
q.setString("param2", nullToEmptyString(param2));
q.setInteger("crossValidationFoldId", foldId == null ? 0 : foldId);
return q;
}
/**
* todo for oracle need to handle empty strings differently
*
* @param param1
* @return
*/
private String nullToEmptyString(String param1) {
return DBUtil.nullToEmptyString(param1);
}
@SuppressWarnings("unchecked")
@Override
public List<FeatureRank> getThresholdFeatures(String corpusName,
String featureSetName, String label, String evaluationType,
Integer foldId, double param1, String param2,
double evaluationThreshold) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2,
"getThresholdFeatures");
q.setDouble("evaluation", evaluationThreshold);
return q.list();
}
@Override
public void deleteFeatureEvaluation(String corpusName,
String featureSetName, String label, String evaluationType,
Integer foldId, Double param1, String param2) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2,
"getFeatureEvaluationByNK");
FeatureEvaluation fe = (FeatureEvaluation) q.uniqueResult();
if (fe != null) {
// for some reason this isn't working - execute batch updates
// this.sessionFactory.getCurrentSession().delete(fe);
q = this.sessionFactory.getCurrentSession().getNamedQuery(
"deleteFeatureRank");
q.setInteger("featureEvaluationId", fe.getFeatureEvaluationId());
q.executeUpdate();
q = this.sessionFactory.getCurrentSession().getNamedQuery(
"deleteFeatureEval");
q.setInteger("featureEvaluationId", fe.getFeatureEvaluationId());
q.executeUpdate();
}
}
public Map<String, FeatureRank> getFeatureRanks(Set<String> featureNames,
String corpusName, String featureSetName, String label,
String evaluationType, Integer foldId, double param1, String param2) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2,
"getFeatureRankEvaluations");
q.setParameterList("featureNames", featureNames);
@SuppressWarnings("unchecked")
List<FeatureRank> featureRanks = q.list();
Map<String, FeatureRank> frMap = new HashMap<String, FeatureRank>(
featureRanks.size());
for (FeatureRank fr : featureRanks)
frMap.put(fr.getFeatureName(), fr);
return frMap;
}
public Map<String, Double> getFeatureRankEvaluations(
Set<String> featureNames, String corpusName, String featureSetName,
String label, String evaluationType, Integer foldId, double param1,
String param2) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2,
"getFeatureRankEvaluations");
q.setParameterList("featureNames", featureNames);
List<FeatureRank> featureRanks = q.list();
Map<String, Double> evalMap = new HashMap<String, Double>(
featureRanks.size());
for (FeatureRank fr : featureRanks)
evalMap.put(fr.getFeatureName(), fr.getEvaluation());
return evalMap;
}
@Override
public Map<String, Double> getFeatureRankEvaluations(String corpusName,
String featureSetName, String label, String evaluationType,
Integer foldId, double param1, String param2) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
label, evaluationType, foldId, param1, param2, "getTopFeatures");
@SuppressWarnings("unchecked")
List<FeatureRank> listFeatureRank = q.list();
Map<String, Double> mapFeatureEval = new HashMap<String, Double>(
listFeatureRank.size());
for (FeatureRank r : listFeatureRank) {
mapFeatureEval.put(r.getFeatureName(), r.getEvaluation());
}
return mapFeatureEval;
}
@Override
@SuppressWarnings("unchecked")
public List<Object[]> getCorpusCuiTuis(String corpusName,
String conceptGraphName, String conceptSetName) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, conceptSetName,
null, InfoContentEvaluator.INFOCONTENT, 0, 0d,
conceptGraphName, "getCorpusCuiTuis");
return q.list();
}
@Override
public Map<String, Double> getInfoContent(String corpusName,
String conceptGraphName, String conceptSet) {
return getFeatureRankEvaluations(corpusName, conceptSet, null,
InfoContentEvaluator.INFOCONTENT, 0, 0, conceptGraphName);
}
@Override
public List<ConceptInfo> getIntrinsicInfoContent(
String conceptGraphName) {
Query q = prepareUniqueFeatureEvalQuery(null, null, null,
IntrinsicInfoContentEvaluator.INTRINSIC_INFOCONTENT, null, null,
conceptGraphName, "getIntrinsicInfoContent");
return (List<ConceptInfo>)q.list();
}
public Integer getMaxDepth(String conceptGraphName) {
Query q = prepareUniqueFeatureEvalQuery(null, null, null,
IntrinsicInfoContentEvaluator.INTRINSIC_INFOCONTENT, null, null,
conceptGraphName, "getMaxFeatureRank");
return (Integer)q.uniqueResult();
}
@Override
public void saveFeatureParentChild(FeatureParentChild parchd) {
this.sessionFactory.getCurrentSession().save(parchd);
}
@Override
public List<FeatureRank> getImputedFeaturesByPropagatedCutoff(
String corpusName, String conceptSetName, String label,
String evaluationType, String conceptGraphName,
String propEvaluationType, int propRankCutoff) {
Query q = prepareUniqueFeatureEvalQuery(corpusName, conceptSetName,
label, evaluationType, 0, 0d, conceptGraphName,
"getImputedFeaturesByPropagatedCutoff");
q.setInteger("propRankCutoff", propRankCutoff);
q.setString("propEvaluationType", propEvaluationType);
return q.list();
}
}