Package edu.msu.cme.rdp.classifier.train.validation.crossvalidate

Source Code of edu.msu.cme.rdp.classifier.train.validation.crossvalidate.CrossValidate

/*
* Copyright (C) 2012 Michigan State University <rdpstaff at msu.edu>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

package edu.msu.cme.rdp.classifier.train.validation.crossvalidate;

import edu.msu.cme.rdp.classifier.train.validation.StatusCount;
import edu.msu.cme.rdp.classifier.train.validation.ValidationClassificationResult;
import edu.msu.cme.rdp.classifier.train.validation.ValidClassificationResultFacade;
import edu.msu.cme.rdp.classifier.train.validation.DecisionMaker;
import edu.msu.cme.rdp.classifier.train.LineageSequence;
import edu.msu.cme.rdp.classifier.train.LineageSequenceParser;
import edu.msu.cme.rdp.classifier.train.validation.HierarchyTree;
import edu.msu.cme.rdp.classifier.train.validation.Taxonomy;
import edu.msu.cme.rdp.classifier.train.validation.TreeFactory;
import edu.msu.cme.rdp.readseq.utils.ResampleSeqFile;
import edu.msu.cme.rdp.readseq.utils.orientation.GoodWordIterator;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
*
* @author wangqion
*/
  public class CrossValidate {

   
    /**
     * The method randomly selects a fraction of sequences from the source file as test set,
     * used the remaining sequences from the source file as training set.
     * @param tax_file
     * @param source_file
     * @param selectedTestSeqIDs
     * @throws IOException
     */
    public ArrayList<HashMap> runTest(File tax_file, File source_file, File out_file,
            String rdmSelectedRank, float fraction, Integer partialLength, boolean useSeed, int min_bootstrap_words) throws IOException{
        BufferedWriter outWriter = new BufferedWriter(new FileWriter(out_file));
       
        Set<String> selectedTestSeqIDs = null;
        if ( rdmSelectedRank == null){
            selectedTestSeqIDs = ResampleSeqFile.randomSelectSeq(source_file, fraction);
        }else {
            selectedTestSeqIDs = RdmSelectTaxon.randomSelectTaxon(tax_file, source_file, fraction, rdmSelectedRank);
        }
       
        TreeFactory factory = setup(tax_file, source_file, selectedTestSeqIDs );
        DecisionMaker dm = new DecisionMaker(factory);
         // get all the genus node list
        HashMap<String, HierarchyTree> genusNodeMap = new HashMap<String, HierarchyTree>();
        factory.getRoot().getNodeMap(Taxonomy.GENUS, genusNodeMap);
        if (genusNodeMap.isEmpty()) {
          throw new IllegalArgumentException("\nThere is no node in GENUS level!");
        }
       
        HashMap<String,HashSet> rankNodeMap = new HashMap<String,HashSet>();
        for (String rank: factory.getRankSet()){
            ArrayList<HierarchyTree> nodeList = new ArrayList<HierarchyTree>();
            factory.getRoot().getNodeList(rank, nodeList);
            HashSet<String> nodeNameSet = getnodeNameSet(nodeList);
            rankNodeMap.put(rank, nodeNameSet);
        }

        ArrayList<HashMap> statusCountList = new ArrayList<HashMap>();
        // initialize a list of statusCount, one for each bootstrap from 0 to 100
        for ( int b = 0; b <= 100; b++){
             HashMap<String, StatusCount> statusCountMap = new HashMap<String, StatusCount>();
             statusCountList.add(statusCountMap);
             for (String rank: factory.getRankSet()){
                 statusCountMap.put(rank, new StatusCount());
             }
         }

        int totalTest = 0;
        int totalSeq = 0;
        LineageSequenceParser parser = new LineageSequenceParser(source_file);
        while ( parser.hasNext()){
            totalSeq ++;
          LineageSequence pSeq = parser.next();
          if ( !selectedTestSeqIDs.contains(pSeq.getSeqName()) || pSeq.getSeqString().length() == 0){
              continue;
          }
          GoodWordIterator wordIterator = null ;
          if ( partialLength != null ){
                wordIterator = pSeq.getPartialSeqIteratorbyGoodBases(partialLength.intValue())// test partial sequences with good words only

          }else {
                wordIterator = new GoodWordIterator(pSeq.getSeqString()); // full sequence
          }

          if (wordIterator == null || wordIterator.getNumofWords() == 0){
            //System.err.println(pSeq.getSeqName() + " unable to find good sequence");
            continue;
          }
       
          List result = dm.getBestClasspath( wordIterator, genusNodeMap, useSeed, min_bootstrap_words);

          //xxx
          ValidClassificationResultFacade resultFacade = new ValidClassificationResultFacade(pSeq, result);

          compareClassificationResult(factory, resultFacade, rankNodeMap, statusCountList);

          totalTest++;
        }
        parser.close();

        outWriter.write("taxon file\t" + tax_file.getName() + "\n" + "train sequence file\t" + source_file.getName() + "\n");
        outWriter.write("word size\t" + GoodWordIterator.getWordsize() + "\n");
        outWriter.write("minimum number of words for bootstrap\t" + min_bootstrap_words + "\n");
        if ( partialLength != null){
            outWriter.write("length\t" + partialLength + "\n");
        }else {
            outWriter.write("length\t" +"full"+ "\n");
        }
        if ( rdmSelectedRank == null){
            outWriter.write("selectedRank\t" + "NA"+ "\n");
        }else {
            outWriter.write("selectedRank\t" + rdmSelectedRank+ "\n");
        }
       
        outWriter.write("trainingset size\t" + (totalSeq - selectedTestSeqIDs.size()) + "\n");
        outWriter.write("testset size\t" + totalTest + "\n");
       
        outWriter.write(calErrorRate(statusCountList));
        outWriter.close();
        return statusCountList;

    }

    /**
     * use the sequences not in the test set as training set
     * @param tax_file
     * @param source_file
     * @param selectedTestSeqIDs
     * @return
     * @throws IOException
     */
    private TreeFactory setup(File tax_file, File source_file, Set<String> selectedTestSeqIDs) throws IOException {

        TreeFactory factory = new TreeFactory(new FileReader(tax_file));
        LineageSequenceParser parser = new LineageSequenceParser(source_file);

        while ( parser.hasNext() ){
            LineageSequence pSeq = parser.next();
            if ( !selectedTestSeqIDs.contains(pSeq.getSeqName())){
              factory.addSequence( pSeq);
            }
        }
        parser.close();

        //after all the training set is being parsed, calculate the prior probability for all the words.
        factory.calculateWordPrior();
        return factory;
      }
   
    private HashSet<String> getnodeNameSet(ArrayList<HierarchyTree> genusNodeList){
        HashSet<String> nodeNameSet = new HashSet<String>();
        for (HierarchyTree t: genusNodeList ){
            nodeNameSet.add(t.getName());
        }
        return nodeNameSet;
    }

    /** If we only care about if the genus in the training set or not, there are four status
     * TP: bootstrap above cutoff, labeled taxon in training set
     * FN: bootstrap below cutoff, labeled taxon in training set
     * FP: bootstrap above cutoff, labeled taxon NOT in training set
     * TN: bootstrap below cutoff, labeled taxon NOT in training set
     *
     **/
     private void compareClassificationResult( TreeFactory factory, ValidClassificationResultFacade resultFacade,
       HashMap<String,HashSet> rankNodeMap, ArrayList<HashMap> statusCountList ) throws IOException{
         // determine assignment status
        // find all the taxa for the ancestors
        HashMap<String, Taxonomy> labeledTaxonMap = new HashMap<String, Taxonomy>();
        labeledTaxonMap.put(factory.getRoot().getTaxonomy().getHierLevel(), factory.getRoot().getTaxonomy());
        int pid = factory.getRoot().getTaxonomy().getTaxID();
        for ( int i = 1 ; i < resultFacade.getAncestors().size(); i++){
            Taxonomy tax = factory.getTaxonomy(resultFacade.getSeqName(), (String) resultFacade.getAncestors().get(i), pid, i);
            labeledTaxonMap.put(tax.getHierLevel(), tax);
            pid = tax.getTaxID();
        }


        List<ValidationClassificationResult> hitList = resultFacade.getRankAssignment();
        for ( ValidationClassificationResult curRankResult: hitList){
            String curRank = curRankResult.getBestClass().getTaxonomy().getHierLevel();
            // find the corresponding ancestor at the current rank
            Taxonomy matchingRankTaxon = labeledTaxonMap.get(curRank);
            if ( matchingRankTaxon == null) {  // no match rank found
                // System.err.println("no matching rank labeled taxon found for " + resultFacade.getSeqName() + " at " + curRank );
                continue;
            }
            HashSet<String> nodeNameSet = rankNodeMap.get(curRank);
            if ( nodeNameSet != null){
                  
                int bootstrap = (int)(curRankResult.getNumOfVotes()*100);
                //System.err.println( "rank: " + curRank + "\t" + matchingRankTaxon.getName() + "\t" + nodeNameSet.contains( matchingRankTaxon.getName()) + "\t" + bootstrap +"\t");
                if ( nodeNameSet.contains( matchingRankTaxon.getName())){  // TP or FN
                    for( int b = 0; b <= bootstrap; b++){
                        ((StatusCount)statusCountList.get(b).get(curRank)).incNumTP(1);
                    }
                    for( int b = bootstrap+1; b < statusCountList.size(); b++){
                        ((StatusCount)statusCountList.get(b).get(curRank)).incNumFN(1);
                    }

                }else {// TN or FP
                    for( int b = 0; b <= bootstrap; b++){
                        ((StatusCount)statusCountList.get(b).get(curRank)).incNumFP(1);
                    }
                    for( int b = bootstrap+1; b < statusCountList.size(); b++){
                        ((StatusCount)statusCountList.get(b).get(curRank)).incNumTN(1);
                    }
                 }

            }
        }

       
     }

     /**
      * sensitivity = #TP / (#TP + #FN)
      * specificity = #TN / (#TN + #FP)
      * @param statusCountList
      */
     public String calErrorRate(ArrayList<HashMap> statusCountList){
        StringBuilder ret = new StringBuilder();
        ret.append("\nbootstrap\t1-Specificity\tSensitivity\n");
        ret.append("bootstrap");
        HashMap<String, StatusCount> statusCountMap = statusCountList.get(0);
        for ( String rank: statusCountMap.keySet()){
            if ( rank.startsWith("sub")) continue;
            ret.append("\t" + rank + "_Spec" +"\t" + rank + "_Sens" );
        }
        ret.append("\n");

        for( int b = 0; b < statusCountList.size(); b++){
            ret.append(b );
            statusCountMap = statusCountList.get(b);
            for ( String rank: statusCountMap.keySet()){
                if ( rank.startsWith("sub")) continue;
                StatusCount st = statusCountMap.get(rank);
                double se = st.calSensitivity();
                double sp = st.calSpecificity();
                ret.append("\t" + (1-sp) + "\t" + se  );
            }
            ret.append("\n");
        }
        return ret.toString();
     }

}
TOP

Related Classes of edu.msu.cme.rdp.classifier.train.validation.crossvalidate.CrossValidate

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.