Package pattern.model.lm

Source Code of pattern.model.lm.RegressionModel

/*
* Copyright (c) 2007-2013 Concurrent, Inc. All Rights Reserved.
*
* Project and contact information: http://www.concurrentinc.com/
*/

package pattern.model.lm;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import pattern.PMML;
import pattern.PatternException;
import pattern.model.Model;
import pattern.predictor.Predictor;
import pattern.predictor.PredictorFactory;
import storm.trident.tuple.TridentTuple;

public class RegressionModel extends Model implements Serializable {
  /** Field LOG */
  private static final Logger LOG = LoggerFactory
      .getLogger(RegressionModel.class);

  public Double intercept = 0.0;
  public List<Predictor> predictors = new ArrayList<Predictor>();

  /**
   * Constructor for a RegressionModel as a standalone classifier (PMML
   * versions 1-3).
   *
   * @param pmml
   *            PMML model
   * @throws PatternException
   */
  public RegressionModel(PMML pmml) throws PatternException {
    schema = pmml.getSchema();
    schema.parseMiningSchema(pmml
        .getNodeList("/PMML/RegressionModel/MiningSchema/MiningField"));

    String node_expr = "/PMML/RegressionModel/RegressionTable[1]";
    Element model_node = (Element) pmml.getNodeList(node_expr).item(0);

    intercept = Double.valueOf(model_node.getAttribute("intercept"));
    LOG.debug("Intercept: " + intercept);

    NodeList child_nodes = model_node.getChildNodes();

    for (int i = 0; i < child_nodes.getLength(); i++) {
      Node child = child_nodes.item(i);

      if (child.getNodeType() == Node.ELEMENT_NODE) {
        Predictor pred = PredictorFactory.getPredictor(schema,
            (Element) child);
        predictors.add(pred);
        LOG.debug(pred.toString());
      }
    }
  }

  /**
   * Prepare to classify with this model. Called immediately before the
   * enclosing Operation instance is put into play processing Tuples.
   */
  @Override
  public void prepare() {
    // not needed
  }

  /**
   * Classify an input tuple, returning the predicted label.
   *
   *
   * @param values
   *            tuple values
   * @param fields
   *            field names
   * @return String
   * @throws PatternException
   */
  @Override
  public String classifyTuple(TridentTuple values) throws PatternException {
    Map<String, Object> param_map = schema.getParamMap(values);
    double result = intercept;

    for (Predictor pred : predictors) {
      double term = pred.calcTerm(param_map);
      result += term;
    }

    LOG.debug("result: " + result);

    return Double.toString(result);
  }

  /** @return String */
  @Override
  public String toString() {
    StringBuilder buf = new StringBuilder();

    if (schema != null) {
      buf.append(schema);
      buf.append("\n");
      buf.append("---------");
      buf.append("\n");
      buf.append(predictors);
      buf.append("---------");
      buf.append("\n");
    }

    return buf.toString();
  }
}
TOP

Related Classes of pattern.model.lm.RegressionModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.