Package edu.msu.cme.rdp.classifier.train.validation.distance

Source Code of edu.msu.cme.rdp.classifier.train.validation.distance.TaxaSimilarityMain

/*
* Copyright (C) 2014 wangqion
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*/

package edu.msu.cme.rdp.classifier.train.validation.distance;

import edu.msu.cme.rdp.alignment.AlignmentMode;
import edu.msu.cme.rdp.alignment.pairwise.PairwiseAligner;
import edu.msu.cme.rdp.alignment.pairwise.PairwiseAlignment;
import edu.msu.cme.rdp.alignment.pairwise.ScoringMatrix;
import edu.msu.cme.rdp.alignment.pairwise.rna.DistanceModel;
import edu.msu.cme.rdp.alignment.pairwise.rna.IdentityDistanceModel;
import edu.msu.cme.rdp.alignment.pairwise.rna.OverlapCheckFailedException;
import edu.msu.cme.rdp.classifier.train.LineageSequence;
import edu.msu.cme.rdp.classifier.train.LineageSequenceParser;
import edu.msu.cme.rdp.classifier.train.validation.HierarchyTree;
import edu.msu.cme.rdp.classifier.train.validation.TreeFactory;
import edu.msu.cme.rdp.readseq.utils.kmermatch.KmerMatchCore;
import edu.msu.cme.rdp.readseq.utils.kmermatch.NuclSeqMatch;
import edu.msu.cme.rdp.readseq.utils.orientation.GoodWordIterator;
import java.awt.BasicStroke;
import java.awt.Font;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.TreeSet;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.NumberTickUnit;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.statistics.BoxAndWhiskerItem;
import org.jfree.data.statistics.DefaultBoxAndWhiskerCategoryDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

/**
*
* @author wangqion
*/
public class TaxaSimilarityMain {
   
    public static String[] RANKS = { "norank", "domain", "phylum", "class", "order", "family", "genus"};
    private ArrayList<Short> withinLowestRankSabSet = new ArrayList<Short>();
    private ArrayList<Short> diffLowestRankSabSet = new ArrayList<Short>();
    private List<String> ranks = new ArrayList<String>();
    private DecimalFormat format = new DecimalFormat("#.###");
    private HashMap<String, long[]> sabCoutMap = new HashMap<String, long[]>()// key = rank, value, count of the sab scores
    private final int BINSIZE = 101;
    private ScoringMatrix scoringMatrix = ScoringMatrix.getDefaultNuclMatrix();
    private AlignmentMode mode = AlignmentMode.overlap_trim;
    private static DistanceModel dist = new IdentityDistanceModel(true);

   
    public TaxaSimilarityMain( List<String> selectedRanks){
        for ( String r: selectedRanks){
            this.ranks.add(r.toLowerCase());
        }
        for ( String rank: ranks){
            sabCoutMap.put(rank.toLowerCase(), new long[BINSIZE]);
        }
       
    }
   
    public static List<String> readRanks(String rankFile) throws IOException {
        List<String> ranks = new ArrayList();
        BufferedReader reader = new BufferedReader(new FileReader(new File(rankFile)));
        String line = null;
        while ( (line=reader.readLine()) != null){
            ranks.add(line.trim());
        }
        return ranks;
    }
   
    public HashMap<String,HierarchyTree> getAncestorNodes(HierarchyTree root, String seqName, List<String> ancestors){
        HashMap<String,HierarchyTree> ancestorNodes = new HashMap<String,HierarchyTree>();
        if ( !ancestors.get(0).equals(root.getName())){
            throw new IllegalArgumentException("Sequence " + seqName + " does not have the same root taxon" + root.getName());
        }
        ancestorNodes.put(root.getTaxonomy().getHierLevel(), root);
        HierarchyTree curParent = root;
        for (int i = 1; i < ancestors.size(); i++){
           
            HierarchyTree  node = curParent.getSubclassbyName(ancestors.get(i));
            if ( node == null){
                throw new IllegalArgumentException("Sequence " + seqName + " cannot find ancestor node: " + ancestors.get(i));
            }
            ancestorNodes.put(node.getTaxonomy().getHierLevel().toLowerCase(), node);
            curParent = node;
        }       
        return ancestorNodes;
    }
   
   
   
    public void calSabSimilarity(String taxonFile, String trainSeqFile, String testSeqFile) throws IOException{       
        TreeFactory factory = new TreeFactory(new FileReader(taxonFile));
        factory.buildTree();
        // get the lineage of the trainSeqFile 
        LineageSequenceParser trainParser = new LineageSequenceParser(new File(trainSeqFile));
        HashMap<String, List<String>> lineageMap = new HashMap<String, List<String>>();
        while (trainParser.hasNext()) {
            LineageSequence seq = (LineageSequence) trainParser.next();
            lineageMap.put(seq.getSeqName(), seq.getAncestors());
           
         }
        trainParser.close();
        NuclSeqMatch sabCal = new NuclSeqMatch(trainSeqFile);
        LineageSequenceParser parser = new LineageSequenceParser(new File(testSeqFile));

        int count = 0;
        while (parser.hasNext()) {
            LineageSequence seq = (LineageSequence) parser.next();
            HashMap<String,HierarchyTree> queryAncestorNodes = getAncestorNodes(factory.getRoot(), seq.getSeqName(), seq.getAncestors());
           TreeSet<KmerMatchCore.BestMatch> matchResults = sabCal.findAllMatches(seq);
           
            short withinLowestRankSab = -1;
            short diffLowestRankSab = -1
            String bestDiffLowestRankMatch = null;
            for (KmerMatchCore.BestMatch match: matchResults){
                if ( match.getBestMatch().getSeqName().equals(seq.getSeqName())) continue;
                short sab = (short)(Math.round(100*match.getSab()));
                HashMap<String,HierarchyTree> matchAncestorNodes = getAncestorNodes(factory.getRoot(), match.getBestMatch().getSeqName(), lineageMap.get(match.getBestMatch().getSeqName()));
                boolean withinTaxon = false;
                for (int i = ranks.size() -1; i >=0; i--){                   
                    HierarchyTree queryTaxon = queryAncestorNodes.get( ranks.get(i));
                    HierarchyTree matchTaxon = matchAncestorNodes.get( ranks.get(i));
                    if ( queryTaxon != null && matchTaxon != null){
                        if ( queryTaxon.getName().equals(matchTaxon.getName())){
                            if ( !withinTaxon){  // if the query and match are not in the same child taxon, add sab to the current taxon
                                (sabCoutMap.get(ranks.get(i)))[sab]++;
                            }
                            withinTaxon = true;                           
                        }else {
                            withinTaxon = false;
                        }
                    }
                   
                } 
               
                // find within or different lowest level rank sab score, be either species or genus or any rank
                HierarchyTree speciesQueryTaxon = queryAncestorNodes.get( ranks.get(ranks.size()-1));   
                HierarchyTree speciesMatchTaxon = matchAncestorNodes.get( ranks.get(ranks.size()-1));
               
                if ( speciesQueryTaxon != null && speciesMatchTaxon != null && speciesQueryTaxon.getName().equals(speciesMatchTaxon.getName())){
                    withinLowestRankSab = sab >= withinLowestRankSab ? sab: withinLowestRankSab;
                }else {
                   
                    if ( sab >= diffLowestRankSab ){
                        bestDiffLowestRankMatch = match.getBestMatch().getSeqName();
                        diffLowestRankSab = sab;
                    }
                }
            }
            if ( withinLowestRankSab > 0){
                withinLowestRankSabSet.add(withinLowestRankSab);
            }
            if ( diffLowestRankSab > 0 ){
                diffLowestRankSabSet.add(diffLowestRankSab);
            }
            //System.out.println(seq.getSeqName() + "\t" + withinLowestRankSab + "\t" + diffLowestRankSab );
            count++;
            if ( count % 100 == 0){
                System.out.println(count);
            }
        }
        parser.close();
   
    }
   
    public void calPairwiseSimilaritye(String taxonFile, String trainSeqFile, String testSeqFile) throws IOException, OverlapCheckFailedException{       
        TreeFactory factory = new TreeFactory(new FileReader(taxonFile));
        factory.buildTree();
        // get the lineage of the trainSeqFile 
        LineageSequenceParser trainParser = new LineageSequenceParser(new File(trainSeqFile));
        ArrayList<LineageSequence> trainSeqList = new ArrayList<LineageSequence>();
        while (trainParser.hasNext()) {
            LineageSequence seq = (LineageSequence) trainParser.next();
            trainSeqList.add(seq);
         }
        trainParser.close();
        LineageSequenceParser parser = new LineageSequenceParser(new File(testSeqFile));

        while (parser.hasNext()) {
            LineageSequence seq = (LineageSequence) parser.next();
            HashMap<String,HierarchyTree> queryAncestorNodes = getAncestorNodes(factory.getRoot(), seq.getSeqName(), seq.getAncestors());
           
            for (LineageSequence trainSeq: trainSeqList){
                if ( trainSeq.getSeqName().equals(seq.getSeqName())) continue;
               
                HashMap<String,HierarchyTree> matchAncestorNodes = getAncestorNodes(factory.getRoot(), trainSeq.getSeqName(), trainSeq.getAncestors());
                boolean withinTaxon = false;
                String lowestCommonRank = null;
                for (int i = ranks.size() -1; i >=0; i--){                   
                    HierarchyTree queryTaxon = queryAncestorNodes.get( ranks.get(i));
                    HierarchyTree matchTaxon = matchAncestorNodes.get( ranks.get(i));
                    if ( queryTaxon != null && matchTaxon != null){
                        if ( queryTaxon.getName().equals(matchTaxon.getName())){
                            if ( !withinTaxon){  // if the query and match are not in the same child taxon, add sab to the current taxon
                                lowestCommonRank = ranks.get(i);
                                //(sabCoutMap.get(ranks.get(i)))[sab]++;
                            }
                            withinTaxon = true;                           
                        }else {
                            withinTaxon = false;
                        }
                    }
                } 
               
                if ( lowestCommonRank == null){  // not the rank we care
                    continue;
                }

                // we need to use overlap_trim mode and calculate distance as metric to count insertions, deletions and mismatches.
                PairwiseAlignment result = PairwiseAligner.align(seq.getSeqString().replaceAll("U", "T"), trainSeq.getSeqString().replaceAll("U", "T"), scoringMatrix, mode);
                short sab = (short) (100 - 100*dist.getDistance(result.getAlignedSeqj().getBytes(), result.getAlignedSeqi().getBytes(), 0));
                sabCoutMap.get(lowestCommonRank)[sab]++;
                               
            }          
        }
        parser.close();
   
    }
   
    public void createPlot(String plotTitle, File outdir) throws IOException{
        XYSeriesCollection dataset = new XYSeriesCollection();
        DefaultBoxAndWhiskerCategoryDataset scatterDataset = new DefaultBoxAndWhiskerCategoryDataset();

        PrintStream boxchart_dataStream = new PrintStream(new File(outdir, plotTitle + ".boxchart.txt"));
       
        boxchart_dataStream.println("#\tkmer" + "\trank" + "\t" + "max" + "\t" + "avg" + "\t" + "min" +
                "\t" + "Q1" + "\t" + "median" + "\t" + "Q3" + "\t" + "98Pct" + "\t" + "2Pct" + "\t" + "comparisons" + "\t" + "sum");
        for ( int i = 0; i < ranks.size(); i++){
            long[] countArray = sabCoutMap.get(ranks.get(i));
            if ( countArray == null) continue;
           
            double sum = 0.0;
            int max = 0;
            int min = 100;
            double mean = 0;
            int Q1 = -1;
            int median = -1;
            int Q3 = -1;
            int pct_98 =-1;
            int pct_2 = -1;
            long comparisons = 0;
            int minOutlier = 0// we don't care about the outliers
            int maxOutlier = 0//
           
            XYSeries series = new XYSeries(ranks.get(i));
           
            for ( int c = 0; c< countArray.length; c++){
                if ( countArray[c] == 0 ) continue;
                comparisons += countArray[c];
                sum += countArray[c] * c;
                if ( c < min){
                    min = c;
                }
                if ( c > max){
                    max = c;
                }
            }
           
            // create series
            double cum = 0;
            for ( int c = 0; c< countArray.length; c++){
                if ( countArray[c] == 0 ) continue;
                cum += countArray[c];
                int pct = (int) Math.floor(100*cum/comparisons);
                series.add(c, pct);
               
                if ( pct_2 == -1 && pct >=5){
                    pct_2 = c;
                }
                if ( Q3 == -1 && pct >=25){
                    Q3 = c;
                }
                if ( median == -1 && pct >=50){
                    median = c;
                }
                if ( Q1 == -1 && pct >=75){
                    Q1 = c;
                }
                if ( pct_98 == -1 && pct >=98){
                    pct_98 = c;
                }
            }
            if ( !series.isEmpty()) {
                dataset.addSeries(series);
               
                BoxAndWhiskerItem item = new BoxAndWhiskerItem(sum/comparisons, median, Q1, Q3, pct_2, pct_98,  minOutlier,  maxOutlier, new ArrayList());
                scatterDataset.add(item, ranks.get(i), "");
               
                boxchart_dataStream.println("#\t" + GoodWordIterator.getWordsize() + "\t" + ranks.get(i) + "\t" + max + "\t" + format.format(sum/comparisons) + "\t" + min +
                "\t" + Q1 + "\t" + median + "\t" + Q3 + "\t" + pct_98 + "\t" + pct_2 + "\t" + comparisons + "\t" + sum);
            }
        } 
        boxchart_dataStream.close();      
        Font lableFont = new Font("Helvetica", Font.BOLD, 28);
       
        JFreeChart chart = ChartFactory.createXYLineChart(plotTitle, "Similarity%", "Percent Comparisions",  dataset,  PlotOrientation.VERTICAL, true, true, false  );
        ((XYPlot) chart.getPlot()).getRenderer().setStroke( new BasicStroke( 2.0f ));
        chart.getLegend().setItemFont(new Font("Helvetica", Font.BOLD, 24));
        chart.getTitle().setFont(lableFont);
        ((XYPlot) chart.getPlot()).getDomainAxis().setLabelFont(lableFont);
        ((XYPlot) chart.getPlot()).getDomainAxis().setTickLabelFont(lableFont);
        ValueAxis rangeAxis = ((XYPlot) chart.getPlot()).getRangeAxis();
        rangeAxis.setRange(0,100);
        rangeAxis.setTickLabelFont(lableFont);
        rangeAxis.setLabelFont(lableFont);
        ((NumberAxis)rangeAxis).setTickUnit(new NumberTickUnit(5));
        ChartUtilities.writeScaledChartAsPNG(new PrintStream(new File(outdir, plotTitle + ".linechart.png")), chart, 800, 1000, 3, 3);

        BoxPlotUtils.createBoxplot(scatterDataset, new PrintStream(new File(outdir, plotTitle + ".boxchart.png")), plotTitle, "Rank", "Similarity%", lableFont);

    }
   
 
   
       
    /**
     * This calculates the average similarity (Sab score or pairwise alignment) between taxa at given ranks and plot the box and whisker plot and accumulation curve.
     * The distances associate to a given rank contains the distances between different child taxa. It does not include the distances within the same child taxa.
     * For example, if a query and it's closest match are from the same genus, the distance value is added to that genus.
     * If there are from different genera but the same family, the distance value is added to that family, etc.
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException, OverlapCheckFailedException{
        String usage = "Usage: taxonfile trainset.fasta query.fasta outdir kmersize rankFile sab|pw \n" +
                "  This program calculates the average similarity (Sab score, or pairwise alignment) within taxa\n" +
                "  and plot the box and whisker plot and accumulation curve plot. \n" +
                "  rankFile: a file contains a list of ranks to be calculated and plotted. One rank per line, no particular order required. \n" +
                "  Note pw is extremely slower, recommended only for lower ranks such as species, genus and family. ";
       
       
        if ( args.length != 7 ){
            System.err.println(usage);
            System.exit(1);
        }
        List<String> ranks = readRanks(args[5]);      
        File outdir = new File(args[3]);
        if ( !outdir.isDirectory()){
            System.err.println("outdir must be a directory");
            System.exit(1);
        }
        int kmer = Integer.parseInt(args[4]);
        GoodWordIterator.setWordSize(kmer);       
        TaxaSimilarityMain theObj = new TaxaSimilarityMain(ranks);

        String plotTitle = new File(args[2]).getName();
        int index = plotTitle.indexOf(".");
        if ( index != -1){
            plotTitle = plotTitle.substring(0, index);         
        }
        if ( args[6].equalsIgnoreCase("sab")){
            theObj.calSabSimilarity(args[0], args[1], args[2]);
        }else {
            theObj.calPairwiseSimilaritye(args[0], args[1], args[2]);
        }
       
        theObj.createPlot(plotTitle, outdir);
       
    }
   
}
TOP

Related Classes of edu.msu.cme.rdp.classifier.train.validation.distance.TaxaSimilarityMain

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.