Package cc.mallet.types

Source Code of cc.mallet.types.HashedSparseVector

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
   A sparse vector whose existing (present) entries can be modified in place.
   Entries that are absent (implicitly zero) cannot be added after construction.
  
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.types;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Arrays;
import java.util.logging.*;
import java.io.*;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Vector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.PropertyList;
import gnu.trove.TIntIntHashMap;

public class HashedSparseVector extends SparseVector implements Serializable
{
  private static Logger logger = MalletLogger.getLogger(SparseVector.class.getName());

 
  TIntIntHashMap index2location;
  int maxIndex;
 
  public HashedSparseVector (int[] indices, double[] values,
                       int capacity, int size,
                       boolean copy,
                       boolean checkIndicesSorted,
                       boolean removeDuplicates)
  {
    super (indices, values, capacity, size, copy, checkIndicesSorted, removeDuplicates);
    assert (indices != null);
  }

  /** Create an empty vector */
  public HashedSparseVector ()
  {
    super (new int[0], new double[0], 0, 0, false, false, false);
  }

  /** Create non-binary vector, possibly dense if "featureIndices" or possibly sparse, if not */
  public HashedSparseVector (int[] featureIndices,
                       double[] values)
  {
    super (featureIndices, values);
  }

  /** Create binary vector */
  public HashedSparseVector (int[] featureIndices)
  {
    super (featureIndices);
  }

  // xxx We need to implement this in FeatureVector subclasses
  public ConstantMatrix cloneMatrix ()
  {
    return new HashedSparseVector (indices, values);
  }

  public ConstantMatrix cloneMatrixZeroed () {
    assert (values != null);
    int[] newIndices = new int[indices.length];
    System.arraycopy (indices, 0, newIndices, 0, indices.length);
    HashedSparseVector sv = new HashedSparseVector (newIndices, new double[values.length],
                             values.length, values.length, false, false, false);
    // share index2location trick ala IndexedSparseVector
    if (index2location != null) {
      sv.index2location = index2location;
      sv.maxIndex = maxIndex;
    }
    return sv;
  }
 
  // Methods that change values

  public void indexVector ()
  {
    if ((index2location == null) && (indices.length > 0))
      setIndex2Location ();
  }

  private void setIndex2Location ()
  {
    //System.out.println ("HashedSparseVector setIndex2Location indices.length="+indices.length+" maxindex="+indices[indices.length-1]);
    assert (index2location == null);
    assert (indices.length > 0);
    this.maxIndex = indices[indices.length - 1];
    this.index2location = new TIntIntHashMap (numLocations ());
    //index2location.setDefaultValue (-1);
    for (int i = 0; i < indices.length; i++)
      index2location.put (indices[i], i);
  }

  public final void setValue (int index, double value) {
    if (index2location == null)
      setIndex2Location ();
    int location = index2location.get(index);
    if (index2location.contains (index))
      values[location] = value;
    else
      throw new IllegalArgumentException ("Trying to set value that isn't present in HashedSparseVector");
  }

  public final void setValueAtLocation (int location, double value)
  {
    values[location] = value;
  }

  // I dislike this name, but it's consistent with DenseVector. -cas
  public void columnPlusEquals (int index, double value) {
    if (index2location == null)
      setIndex2Location ();
    int location = index2location.get(index);
    if (index2location.contains (index))
      values[location] += value;
    else
      throw new IllegalArgumentException ("Trying to set value that isn't present in HashedSparseVector");
  }
   
  public final double dotProduct (DenseVector v) {
    double ret = 0;
    if (values == null)
      for (int i = 0; i < indices.length; i++)
        ret += v.value(indices[i]);
    else
      for (int i = 0; i < indices.length; i++)
        ret += values[i] * v.value(indices[i]);
    return ret;
  }
   
    public final double dotProduct (SparseVector v)
    {
  if (indices.length == 0)
      return 0;
  if (index2location == null)
      setIndex2Location ();
  double ret = 0;
  int vNumLocs = v.numLocations();
  if (values == null) {
      // this vector is binary
      for (int i = 0; i < vNumLocs; i++) {
    int index = v.indexAtLocation(i);
    if (index > maxIndex)
        break;
    if (index2location.contains(index))
        ret += v.valueAtLocation (i);
      }
  } else {
      for (int i = 0; i < vNumLocs; i++) {
    int index = v.indexAtLocation(i);
    if (index > maxIndex)
        break;
   
    if (index2location.containsKey(index)) {
        ret += values[ index2location.get(index) ] * v.valueAtLocation (i);
    }
   
   
    //int location = index2location.get(index);
    //if (location >= 0)
    //  ret += values[location] * v.valueAtLocation (i);
      }
  }
  return ret;
    }
 
    public final void plusEqualsSparse (SparseVector v, double factor)
    {
  if (indices.length == 0)
      return;
  if (index2location == null)
      setIndex2Location ();
  int vNumLocs = v.numLocations();
  for (int i = 0; i < vNumLocs; i++) {
      int index = v.indexAtLocation(i);
      if (index > maxIndex)
    break;
     
      if (index2location.containsKey(index)) {
    values[ index2location.get(index) ] += v.valueAtLocation (i) * factor;
      }
     
      //int location = index2location.get(index);
      //if (location >= 0)
      //      values[location] += v.valueAtLocation (i) * factor;
  }
    }

  public final void plusEqualsSparse (SparseVector v)
  {
    if (indices.length == 0)
      return;
    if (index2location == null)
      setIndex2Location ();
    for (int i = 0; i < v.numLocations(); i++) {
      int index = v.indexAtLocation(i);
      if (index > maxIndex)
        break;
      int location = index2location.get(index);
      if (index2location.contains (index))
        values[location] += v.valueAtLocation (i);
    }
  }
 
  public final void setAll (double v)
  {
    for (int i = 0; i < values.length; i++)
      values[i] = v;
  }


 
  //Serialization

  private static final long serialVersionUID = 1;

  // Version history:
  //   0 == Wrote out index2location.  Probably a bad idea.
  private static final int CURRENT_SERIAL_VERSION = 1;
  static final int NULL_INTEGER = -1;

  private void writeObject (ObjectOutputStream out) throws IOException {
    out.writeInt (CURRENT_SERIAL_VERSION);
    out.writeInt (maxIndex);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt ();
    maxIndex = in.readInt ();

    if (version == 0) {
      // gobble up index2location
      Object obj = in.readObject ();
      if (obj != null && !(obj instanceof TIntIntHashMap)) {
        throw new IOException ("Unexpected object in de-serialization: "+obj);
      }
    }

  }

}
TOP

Related Classes of cc.mallet.types.HashedSparseVector

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.