Package org.fnlp.ml.classifier.linear.update

Source Code of org.fnlp.ml.classifier.linear.update.AbstractPAUpdate

/**
*  This file is part of FNLP (formerly FudanNLP).
*  FNLP is free software: you can redistribute it and/or modify
*  it under the terms of the GNU Lesser General Public License as published by
*  the Free Software Foundation, either version 3 of the License, or
*  (at your option) any later version.
*  FNLP is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*  You should have received a copy of the GNU General Public License
*  along with FudanNLP.  If not, see <http://www.gnu.org/licenses/>.
*  Copyright 2009-2014 www.fnlp.org. All rights reserved.
*/

package org.fnlp.ml.classifier.linear.update;

import org.fnlp.ml.loss.Loss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.sv.HashSparseVector;

/**
* 抽象参数更新类,采用PA算法
* \mathbf{w_{t+1}} = \w_t + {\alpha^*(\Phi(x,y)- \Phi(x,\hat{y}))}.
* \alpha =\frac{1- \mathbf{w_t}^T \left(\Phi(x,y) - \Phi(x,\hat{y})\right)}{||\Phi(x,y) - \Phi(x,\hat{y})||^2}.
*
*/
public abstract class AbstractPAUpdate implements Update {

  /**
   * \mathbf{w_t}^T \left(\Phi(x,y) - \Phi(x,\hat{y})\right)
   */
  protected float diffw;
  /**
   * \Phi(x,y)- \Phi(x,\hat{y})
   */
  protected HashSparseVector diffv;
  protected Loss loss;
  /**
   * 是否使用样本的权重进行加权
   */
  public boolean useInstWeight;

  public AbstractPAUpdate(Loss loss) {
    diffw = 0;
    diffv = new HashSparseVector();
    this.loss = loss;
  }

  @Override
  public float update(Instance inst, float[] weights, int k, float[] extraweight, Object predict, float c) {
    return update(inst, weights, k, extraweight, inst.getTarget(), predict, c);
  }
 

  @Override
  public float update(Instance inst, float[] weights, int k, float[] extraweight, Object target,
      Object predict, float c) {

    int lost = diff(inst, weights, target, predict);
    if(lost==0)
      return 0f;
    float lamda = diffv.l2Norm2();

    if (diffw <= lost) {
      float alpha = (lost - diffw) / lamda;
      if(useInstWeight)
        alpha = alpha*inst.getWeight();
      if(alpha>c){
        alpha = c;
      }
     
      int[] idx = diffv.indices();
      
      for (int i = 0; i < idx.length; i++) {     
        float t = diffv.get(idx[i]) * alpha;
        weights[idx[i]] += t;
        extraweight[idx[i]] += t *k;
      }
      for (int i = 0; i < idx.length; i++) {       
       
      }
    }
   
    diffv.clear();
    diffw = 0;
   
    return loss.calc(target, predict);
  }

  /**
   * 计算预测类别和对照类别之间的距离
   * @param inst 样本实例
   * @param weights 权重
   * @param target 对照类别
   * @param predict 预测类别
   * @return 预测类别和对照类别之间的距离
   */
  protected abstract int diff(Instance inst, float[] weights, Object target,
      Object predict);

  protected void adjust(float[] weights, int ts, int ps) {
    assert (ts != -1 && ps != -1);
    diffv.put(ts, 1.0f);
    diffv.put(ps, -1.0f);
    diffw += weights[ts] - weights[ps];
  }
}
TOP

Related Classes of org.fnlp.ml.classifier.linear.update.AbstractPAUpdate

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle. Contact coftware#gmail.com.