Package de.lmu.ifi.dbs.elki.visualization.visualizers.visunproj

Source Code of de.lmu.ifi.dbs.elki.visualization.visualizers.visunproj.CurveVisFactory

package de.lmu.ifi.dbs.elki.visualization.visualizers.visunproj;

/*
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures

Copyright (C) 2011
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

import java.util.Collection;
import java.util.Iterator;

import org.apache.batik.util.SVGConstants;
import org.w3c.dom.Element;

import de.lmu.ifi.dbs.elki.evaluation.roc.ComputeROCCurve;
import de.lmu.ifi.dbs.elki.evaluation.roc.ComputeROCCurve.ROCResult;
import de.lmu.ifi.dbs.elki.logging.LoggingUtil;
import de.lmu.ifi.dbs.elki.math.DoubleMinMax;
import de.lmu.ifi.dbs.elki.result.HierarchicalResult;
import de.lmu.ifi.dbs.elki.result.IterableResult;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.result.ResultUtil;
import de.lmu.ifi.dbs.elki.utilities.FormatUtil;
import de.lmu.ifi.dbs.elki.utilities.iterator.AbstractFilteredIterator;
import de.lmu.ifi.dbs.elki.utilities.iterator.IterableIterator;
import de.lmu.ifi.dbs.elki.utilities.pairs.DoubleDoublePair;
import de.lmu.ifi.dbs.elki.visualization.VisualizationTask;
import de.lmu.ifi.dbs.elki.visualization.VisualizerContext;
import de.lmu.ifi.dbs.elki.visualization.css.CSSClass;
import de.lmu.ifi.dbs.elki.visualization.css.CSSClassManager.CSSNamingConflict;
import de.lmu.ifi.dbs.elki.visualization.scales.LinearScale;
import de.lmu.ifi.dbs.elki.visualization.style.StyleLibrary;
import de.lmu.ifi.dbs.elki.visualization.svg.SVGPath;
import de.lmu.ifi.dbs.elki.visualization.svg.SVGPlot;
import de.lmu.ifi.dbs.elki.visualization.svg.SVGSimpleLinearAxis;
import de.lmu.ifi.dbs.elki.visualization.svg.SVGUtil;
import de.lmu.ifi.dbs.elki.visualization.visualizers.AbstractVisFactory;
import de.lmu.ifi.dbs.elki.visualization.visualizers.StaticVisualization;
import de.lmu.ifi.dbs.elki.visualization.visualizers.Visualization;

/**
* Visualizer to render a simple 2D curve such as a ROC curve.
*
* @author Erich Schubert
*
* @apiviz.stereotype factory
* @apiviz.uses StaticVisualization oneway - - «create»
* @apiviz.has IterableResult oneway - - visualizes
*/
public class CurveVisFactory extends AbstractVisFactory {
  /**
   * Name for this visualizer.
   */
  private static final String NAME = "Curve";

  /**
   * SVG class name for plot line
   */
  private static final String SERIESID = "series";

  /**
   * Constructor, Parameterizable style - does nothing.
   */
  public CurveVisFactory() {
    super();
  }

  @Override
  public Visualization makeVisualization(VisualizationTask task) {
    VisualizerContext context = task.getContext();
    SVGPlot svgp = task.getPlot();
    IterableResult<DoubleDoublePair> curve = task.getResult();

    setupCSS(context, svgp);
    double scale = StyleLibrary.SCALE;
    final double sizex = scale;
    final double sizey = scale * task.getHeight() / task.getWidth();
    final double margin = context.getStyleLibrary().getSize(StyleLibrary.MARGIN);
    Element layer = SVGUtil.svgElement(svgp.getDocument(), SVGConstants.SVG_G_TAG);
    final String transform = SVGUtil.makeMarginTransform(task.getWidth(), task.getHeight(), sizex, sizey, margin);
    SVGUtil.setAtt(layer, SVGConstants.SVG_TRANSFORM_ATTRIBUTE, transform);

    // determine scaling
    DoubleMinMax minmaxx = new DoubleMinMax();
    DoubleMinMax minmaxy = new DoubleMinMax();
    for(DoubleDoublePair pair : curve) {
      minmaxx.put(pair.first);
      minmaxy.put(pair.second);
    }
    LinearScale scalex = new LinearScale(minmaxx.getMin(), minmaxx.getMax());
    LinearScale scaley = new LinearScale(minmaxy.getMin(), minmaxy.getMax());
    // plot the line
    SVGPath path = new SVGPath();
    for(DoubleDoublePair pair : curve) {
      final double x = scalex.getScaled(pair.first);
      final double y = 1 - scaley.getScaled(pair.second);
      path.drawTo(sizex * x, sizey * y);
    }
    Element line = path.makeElement(svgp);
    line.setAttribute(SVGConstants.SVG_CLASS_ATTRIBUTE, SERIESID);

    // add axes
    try {
      SVGSimpleLinearAxis.drawAxis(svgp, layer, scalex, 0, sizey, sizex, sizey, true, true, context.getStyleLibrary());
      SVGSimpleLinearAxis.drawAxis(svgp, layer, scaley, 0, sizey, 0, 0, true, false, context.getStyleLibrary());
    }
    catch(CSSNamingConflict e) {
      LoggingUtil.exception(e);
    }

    // Add AUC value when found
    if(curve instanceof ROCResult) {
      Collection<String> header = ((ROCResult) curve).getHeader();
      for(String str : header) {
        String[] parts = str.split(":\\s*");
        if(parts[0].equals(ComputeROCCurve.ROCAUC_LABEL) && parts.length == 2) {
          double rocauc = Double.parseDouble(parts[1]);
          StyleLibrary style = context.getStyleLibrary();
          CSSClass cls = new CSSClass(svgp, "unmanaged");
          String lt = "ROC AUC: " + FormatUtil.NF8.format(rocauc);
          double fontsize = style.getTextSize("curve.labels");
          cls.setStatement(SVGConstants.CSS_FONT_SIZE_PROPERTY, SVGUtil.fmt(fontsize));
          cls.setStatement(SVGConstants.CSS_FILL_PROPERTY, style.getTextColor("curve.labels"));
          cls.setStatement(SVGConstants.CSS_FONT_FAMILY_PROPERTY, style.getFontFamily("curve.labels"));
          if(rocauc <= 0.5) {
            Element auclbl = svgp.svgText(sizex * 0.95, sizey * 0.95, lt);
            SVGUtil.setAtt(auclbl, SVGConstants.SVG_STYLE_ATTRIBUTE, cls.inlineCSS());
            // SVGUtil.setAtt(auclbl, SVGConstants.SVG_TEXT_ANCHOR_ATTRIBUTE,
            // SVGConstants.SVG_START_VALUE);
            layer.appendChild(auclbl);
          }
          else {
            Element auclbl = svgp.svgText(sizex * 0.95, sizey * 0.95, lt);
            SVGUtil.setAtt(auclbl, SVGConstants.SVG_STYLE_ATTRIBUTE, cls.inlineCSS());
            SVGUtil.setAtt(auclbl, SVGConstants.SVG_TEXT_ANCHOR_ATTRIBUTE, SVGConstants.SVG_END_VALUE);
            layer.appendChild(auclbl);
          }
        }
      }
    }

    layer.appendChild(line);
    return new StaticVisualization(task, layer);
  }

  /**
   * Setup the CSS classes for the plot.
   *
   * @param svgp Plot
   */
  private void setupCSS(VisualizerContext context, SVGPlot svgp) {
    CSSClass csscls = new CSSClass(this, SERIESID);
    // csscls.setStatement(SVGConstants.SVG_STROKE_WIDTH_ATTRIBUTE, "0.2%");
    csscls.setStatement(SVGConstants.SVG_FILL_ATTRIBUTE, SVGConstants.SVG_NONE_VALUE);
    context.getStyleLibrary().lines().formatCSSClass(csscls, 0, context.getStyleLibrary().getLineWidth(StyleLibrary.PLOT));
    svgp.addCSSClassOrLogError(csscls);
  }

  /**
   * Filter to only retrieve double-double-pair results.
   *
   * @author Erich Schubert
   *
   * @apiviz.exclude
   */
  class CurveFilter extends AbstractFilteredIterator<IterableResult<?>, IterableResult<DoubleDoublePair>> implements IterableIterator<IterableResult<DoubleDoublePair>> {
    /**
     * Parent iterator to use
     */
    Iterator<IterableResult<?>> parent;

    /**
     * Constructor.
     *
     * @param parent Parent iterator to decorate.
     */
    public CurveFilter(Iterator<IterableResult<?>> parent) {
      super();
      this.parent = parent;
    }

    @Override
    protected Iterator<IterableResult<?>> getParentIterator() {
      return parent;
    }

    @Override
    protected IterableResult<DoubleDoublePair> testFilter(IterableResult<?> nextobj) {
      Iterator<?> iterator = nextobj.iterator();
      if(iterator.hasNext()) {
        Object o = iterator.next();
        if(o instanceof DoubleDoublePair) {
          @SuppressWarnings("unchecked")
          final IterableResult<DoubleDoublePair> ret = (IterableResult<DoubleDoublePair>) nextobj;
          return ret;
        }
      }
      return null;
    }

    @Override
    public Iterator<IterableResult<DoubleDoublePair>> iterator() {
      return this;
    }
  }

  @Override
  public void processNewResult(HierarchicalResult baseResult, Result result) {
    final IterableIterator<IterableResult<?>> iterableResults = ResultUtil.filteredResults(result, IterableResult.class);
    final IterableIterator<IterableResult<DoubleDoublePair>> curves = new CurveFilter(iterableResults);
    for (IterableResult<DoubleDoublePair> curve : curves) {
      final VisualizationTask task = new VisualizationTask(NAME, curve, null, this);
      task.width = 1.0;
      task.height = 1.0;
      task.put(VisualizationTask.META_LEVEL, VisualizationTask.LEVEL_STATIC);
      baseResult.getHierarchy().add(curve, task);
    }
  }

  @Override
  public boolean allowThumbnails(VisualizationTask task) {
    // TODO: depending on the curve complexity?
    return false;
  }
}
TOP

Related Classes of de.lmu.ifi.dbs.elki.visualization.visualizers.visunproj.CurveVisFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by ORACLE Inc. Contact coftware#gmail.com.