Package edu.msu.cme.rdp.classifier.train.validation.movingwindow

Source Code of edu.msu.cme.rdp.classifier.train.validation.movingwindow.WindowTester

/*
* WindowTester.java
*
* Created on June 24, 2002, 6:26 PM
*/
/**
*
* @author  wangqion
* @version
*/
package edu.msu.cme.rdp.classifier.train.validation.movingwindow;

import java.util.*;
import java.io.*;
import java.text.*;

import edu.msu.cme.rdp.classifier.train.validation.ValidationClassificationResult;
import edu.msu.cme.rdp.classifier.train.validation.ValidClassificationResultFacade;
import edu.msu.cme.rdp.classifier.train.validation.CorrectAssignment;
import edu.msu.cme.rdp.classifier.train.validation.DecisionMaker;
import edu.msu.cme.rdp.classifier.train.LineageSequence;
import edu.msu.cme.rdp.classifier.train.validation.HierarchyTree;
import edu.msu.cme.rdp.classifier.train.validation.Taxonomy;
import edu.msu.cme.rdp.classifier.train.validation.TreeFactory;
import edu.msu.cme.rdp.readseq.utils.orientation.GoodWordIterator;

/** */
public class WindowTester {

    BufferedWriter outFile;
    private boolean bootstrap = true;
    private int windowIndex = 0;
    private String testRank = Taxonomy.GENUS;
    private Map num_hierLevel = new HashMap()// key is the hierarchy level, value is
    // he number of correctly classified sequences at that level
    List missSeqList = new ArrayList()// this is only for test purpose, if the sequence

    /** Creates new Tester */
    public WindowTester(Writer writer) throws IOException {
        outFile = (BufferedWriter) writer;
    }

    /** classify each sequence from a parser */
    public void classify(TreeFactory factory, ArrayList seqList, Window window, int windowIndex, int min_bootstrap_words)
            throws IOException {
        this.windowIndex = windowIndex;

        //for each sequence with name, and or true path
        List resultList = new ArrayList();

        DecisionMaker dm = new DecisionMaker(factory);
        HierarchyTree root = factory.getRoot();

        HashMap<String, HierarchyTree> genusNodeMap = new HashMap<String, HierarchyTree>();
        factory.getRoot().getNodeMap(testRank, genusNodeMap);

        if (genusNodeMap.isEmpty()) {
            throw new IllegalArgumentException("\nThere is no node in GENUS level!");
        }

        int i = 0;
        Iterator seqIt = seqList.iterator();
        while (seqIt.hasNext()) {
            LineageSequence pSeq = (LineageSequence) seqIt.next();

            GoodWordIterator wordIterator = getPartialSeqIteratorbyWindow(pSeq, window); // full sequence 

            if (wordIterator == null) {
                continue;
            }

            //for leave-one-out testing, we need to remove the word occurrance for
            //the current sequence. This is similiar to hiding a sequence leaf.
            HierarchyTree curTree = genusNodeMap.get((String) pSeq.getAncestors().get(pSeq.getAncestors().size() - 1));


            curTree.hideSeq(wordIterator);
            List result = dm.getBestClasspath( wordIterator, genusNodeMap, false, min_bootstrap_words);

            ValidClassificationResultFacade resultFacade = new ValidClassificationResultFacade(pSeq, result);
            resultFacade.setLabeledNode(curTree);
            compareClassificationResult(resultFacade);

            resultList.add(resultFacade);
            i++;
            // recover the wordOccurrence of the genus node, unhide
            curTree.unhideSeq(wordIterator);

        }

        displayStat();
    }

    private GoodWordIterator getPartialSeqIteratorbyWindow(LineageSequence pSeq, Window w) throws IOException {
        int firstbase = findFirstBase(pSeq.getSeqString());
        int lastbase = findLastBase(pSeq.getSeqString());

        if (firstbase > w.getStart() || lastbase < w.getStop()) {
            return null;
        }


        int stop = w.getStop();
        if (stop > pSeq.getSeqString().length()) {
            stop = pSeq.getSeqString().length();
        }
        String seqString = pSeq.getSeqString().substring(w.getStart() - 1, stop);
        seqString = seqString.replaceAll("-", "");
        // at least half of the window size.

        if (seqString.length() < FindWindowFrame.window_size / 2) {
            return null;
        }
        GoodWordIterator wordIterator = new GoodWordIterator(seqString);
        if (wordIterator.getNumofWords() == 0) {
            wordIterator = null;
        }
        return wordIterator;
    }

    private int findFirstBase(String s) throws IOException {
        StringReader reader = new StringReader(s);
        int c;
        int index = 1;
        while ((c = reader.read()) != -1) {
            char ch = (char) c;
            if (ch != '-') {
                return index;
            }
            index++;
        }
        reader.close();
        return -1;
    }

    private int findLastBase(String s) throws IOException {
        StringReader reader = new StringReader(s);
        int c;
        int index = 1;
        int lastBase = 0;
        while ((c = reader.read()) != -1) {
            char ch = (char) c;
            if (ch != '-') {
                lastBase = index;
            }
            index++;
        }
        reader.close();
        return lastBase;
    }

    /** get all the lowest level nodes in given hierarchy level starting from the given root
     *
    public void getNodeList(HierarchyTree root, String level, List nodeList) {
        if (root == null) {
            return;
        }

        if (((Taxonomy) root.getTaxonomy()).getHierLevel().equalsIgnoreCase(level)) {
            nodeList.add(root);
            return;
        }
        //start from the root of the tree, get the subclasses.
        Collection al = new ArrayList();

        if ((al = root.getSubclasses()).isEmpty()) {
            return;
        }
        Iterator i = al.iterator();
        while (i.hasNext()) {
            getNodeList((HierarchyTree) i.next(), level, nodeList);
        }
    } */

    /** search the tree node that the current test getSequence() belongs to
     */
    /*
    public HierarchyTree getTreeNode(LineageSequence pSeq, List nodeList) {
        Iterator i = nodeList.iterator();
        String genusNode = (String) pSeq.getAncestors().get(pSeq.getAncestors().size() - 1);

        while (i.hasNext()) {
            HierarchyTree tree = (HierarchyTree) i.next();
            if ((tree.getName()).equals(genusNode)) {
                return tree;
            }
        }

        throw new IllegalArgumentException("Test getSequence() name: "
                + pSeq.getSeqName() + " is not found in the tree. Can not continue testing");

    } */

    public void displayTreeErrorRate(HierarchyTree root, int indent) throws IOException {
        int k = 0;
        while (k < indent) {
            outFile.write("  ");
            k++;
        }
        outFile.write(root.getName() + "\t" + root.getTotalSeqs() + "\t" + root.getMissCount() + "\n");
        Iterator i = root.getSubclasses().iterator();

        while (i.hasNext()) {
            displayTreeErrorRate((HierarchyTree) i.next(), indent + 1);
        }
    }

    private void displayResult(List seqs, List aPath) throws IOException {
        int i = 0;
        int size1 = seqs.size();

        for (i = 0; i < size1; i++) {
            outFile.write((String) seqs.get(i) + "*");
            int size = ((List) aPath.get(i)).size();
            if (size == 0) {
                outFile.write("\n");
                continue;
            }
            ValidationClassificationResult result = (ValidationClassificationResult) ((List) aPath.get(i)).get(size - 1);

            outFile.write(((Taxonomy) ((HierarchyTree) result.getBestClass()).getTaxonomy()).getTaxID() + "\n");
        }
    }

    /** Displays the results of the classification, true path, assigned path and the
     * error rate
     */
    private void displayClassification(List seqs)
            throws IOException {
        int i = 0;

        outFile.write("\n missclassified getSequence()s: \n");
        for (i = 0; i < seqs.size(); i++) {
            ValidClassificationResultFacade resultFacade = (ValidClassificationResultFacade) seqs.get(i);
            if (resultFacade.isMissed()) {
                printPath(resultFacade, outFile);
            }
        }

        DecimalFormat df = new DecimalFormat("0.###");

        outFile.write("\n\n**The statistics for each hierarchy level: \n 1: number of correct assigned getSequence()s / number of total getSequence()s for that bin rage\n");
        outFile.write("Level\t100-95\t\t94-90");
        for (i = 9; i > 0; i--) {
            outFile.write("\t\t" + (i * 10 - 1) + "-" + (i - 1) * 10);
        }
        outFile.write("\n");

        System.err.print("\t");
        Iterator it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();
            outFile.write(name + "\t");
            if (this.windowIndex == MainMovingWindow.getBeginIndex() && !name.equals("domain") && !name.equals("norank") && !name.startsWith("sub")) {
                System.err.print(name + "\t");
            }
            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);

            calStandardError(assign);

            for (i = assign.bins - 1; i >= 0; i--) {
                outFile.write(assign.numCorrect[i] + "\t" + assign.numTotal[i] + "\t");
            }
            outFile.write("\n");
        }

        outFile.write("\n\n** 2. The average votes for each bin range \n");
        it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();
            outFile.write(name + " \t ");
            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);

            for (i = assign.bins - 1; i >= 0; i--) {
                if (assign.numTotal[i] == 0) {
                    outFile.write("0\t");
                } else {
                    outFile.write(df.format(assign.sumOfVotes[i] * 100 / assign.numTotal[i]) + "\t");
                }
            }
            outFile.write(" \n");
        }

        outFile.write("\n\n** 3. The percentage of correctness for each bin range (the percentage of #1)\n");
        it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();
            outFile.write(name + " \t ");
            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);

            for (i = assign.bins - 1; i >= 0; i--) {
                if (assign.numTotal[i] == 0) {
                    outFile.write("0\t");
                } else {
                    outFile.write(df.format((float) assign.numCorrect[i] / (float) assign.numTotal[i]) + "\t");
                }
            }
            outFile.write(" \n");
        }

        outFile.write("\n\n** 4. The standard error for each bin range \n");

        it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();
            outFile.write(name + " \t ");
            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);

            for (i = assign.bins - 1; i >= 0; i--) {
                outFile.write(df.format(assign.standardError[i]) + "\t");
            }
            outFile.write(" \n");
        }
    }

    private void displayStat() throws IOException {
        outFile.write("\t");
        Iterator it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();
            if (this.windowIndex == MainMovingWindow.getBeginIndex() && !name.equals("domain") && !name.equals("norank") && !name.startsWith("sub")) {
                outFile.write(name + "\t");
            }
            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);

            calStandardError(assign);

        }

        outFile.write("\n");
        // print to std err
        outFile.write("V\t" + this.windowIndex);
        it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();

            if (name.equals("domain") || name.equals("norank") || name.startsWith("sub")) {
                continue;
            }

            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);
            outFile.write("\t" + getAvgVotes(assign));

        }


        outFile.write("\nC\t" + this.windowIndex);
        it = num_hierLevel.keySet().iterator();
        while (it.hasNext()) {
            String name = (String) it.next();

            if (name.equals("domain") || name.equals("norank") || name.startsWith("sub")) {
                continue;
            }

            CorrectAssignment assign = (CorrectAssignment) num_hierLevel.get(name);

            outFile.write("\t" + getCorrectRate(assign));

        }

    }

    /** print the true path and the assigned path
     * Note: the true path is list of string of getAncestors() for a getSequence(),
     * the assigned path is a list of ClassificationResult for a getSequence().
     */
    private void printPath(ValidClassificationResultFacade resultFacade, BufferedWriter out)
            throws IOException {
        out.write("SEQ: " + resultFacade.getSeqName() + "\n");
        Iterator i = resultFacade.getAncestors().iterator();
        while (i.hasNext()) {
            out.write(i.next() + "\t");
        }
        out.write("\n");

        i = resultFacade.getRankAssignment().iterator();
        while (i.hasNext()) {
            ValidationClassificationResult result = (ValidationClassificationResult) i.next();
            out.write(((HierarchyTree) result.getBestClass()).getName() + "\t");

        }

        out.write("\n");
        i = resultFacade.getRankAssignment().iterator();
        while (i.hasNext()) {
            out.write(((ValidationClassificationResult) i.next()).getNumOfVotes() + "\t");
        }
        out.write("\n");
    }

    /** Compare the assigned path with the true path for the test getSequence(),
     * counts the number of correct classes and the number of getSequence()s for
     * each path level.
     */
    private void compareClassificationResult(ValidClassificationResultFacade resultFacade) {
        HierarchyTree trueParent = resultFacade.getLabeledNode();
        List hitList = resultFacade.getRankAssignment();

        // compare the true taxon and the hit taxon with same rank.
        while (trueParent != null) {

            if (!trueParent.isSingleton()) {
                ValidationClassificationResult hit = null;
                for (int i = 0; i < hitList.size(); i++) {
                    ValidationClassificationResult tmp = (ValidationClassificationResult) hitList.get(i);
                    if ((trueParent.getTaxonomy().getHierLevel()).equals(tmp.getBestClass().getTaxonomy().getHierLevel())) {
                        hit = tmp;
                        break;
                    }
                }

                if (trueParent.getTaxonomy().getHierLevel().equals(testRank)) {
                    trueParent.incNumTotalTestedseq();
                }
                boolean correct = false;


                if (hit != null && trueParent.getTaxonomy().getTaxID() == hit.getBestClass().getTaxonomy().getTaxID()) {
                    correct = true;
                } else {
                    if (trueParent.getTaxonomy().getHierLevel().equals(testRank)) {
                        trueParent.incMissCount();
                    }
                    resultFacade.setMissedRank(trueParent.getTaxonomy().getHierLevel());
                }
                if (hit != null) {
                    increaseCount(hit, trueParent.getTaxonomy().getHierLevel(), correct, num_hierLevel);
                }

            } else {
                // System.err.println(" singleton: " + trueParent.getName() + " " + trueParent.getTaxonomy().getHierLevel() );
            }

            trueParent = trueParent.getParent();

        }

    }

    private void increaseCount(ValidationClassificationResult aResult, String level, boolean correct, Map aMap) {
        // bin range: 100-95, 94-90, 89-80, 79-70.....9-0
        CorrectAssignment assign = (CorrectAssignment) aMap.get(level);
        if (assign == null) {
            assign = new CorrectAssignment();
            aMap.put(level, assign);
        }

        int binIndex = (int) Math.floor(aResult.getNumOfVotes() * 10);
        if (binIndex == 9) {
            if (aResult.getNumOfVotes() >= 0.95) {
                binIndex = 10;   // we put 100-95 to the bin #10
            }
        }
        assign.numTotal[binIndex]++;
        assign.sumOfVotes[binIndex] += aResult.getNumOfVotes();

        if (correct) {
            assign.numCorrect[binIndex]++;
        }

    }

    private void calStandardError(CorrectAssignment assign) {
        for (int i = 0; i < assign.bins; i++) {
            if (assign.numTotal[i] == 0) {
                assign.standardError[i] = (float) 0.0;
            } else {
                float p = (float) assign.numCorrect[i] / (float) assign.numTotal[i];
                assign.standardError[i] = (float) Math.sqrt(p * (1 - p) / assign.numTotal[i]);
            }

        }
    }

    private float getCorrectRate(CorrectAssignment assign) {
        int total = 0;
        int numCorrect = 0;
        for (int i = assign.bins - 1; i >= 0; i--) {
            total += assign.numTotal[i];
            numCorrect += assign.numCorrect[i];
        }

        float retval = 0;
        if (total > 0) {
            retval = (float) numCorrect / (float) total;
        }

        return retval;
    }

    private float getAvgVotes(CorrectAssignment assign) {
        int totalNumOfSeq = 0;
        float sumOfVotes = 0.0f;
        for (int i = 0; i < assign.bins; i++) {
            totalNumOfSeq += assign.numTotal[i];
            sumOfVotes += assign.sumOfVotes[i];
        }

        float avg = 0.0f;
        if (totalNumOfSeq > 0) {
            avg = sumOfVotes / (float) totalNumOfSeq;
        }

        return avg;
    }
}
TOP

Related Classes of edu.msu.cme.rdp.classifier.train.validation.movingwindow.WindowTester

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.