Package org.apache.mahout.math

Source Code of org.apache.mahout.math.AbstractMatrix$TransposeViewVector

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.math;

import com.google.common.collect.Maps;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.PlusMult;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.VectorFunction;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;

/** A few universal implementations of convenience functions */
public abstract class AbstractMatrix implements Matrix {

  protected Map<String, Integer> columnLabelBindings;

  protected Map<String, Integer> rowLabelBindings;

  protected int[] cardinality = new int[2];

  @Override
  public int columnSize() {
    return cardinality[COL];
  }

  @Override
  public int rowSize() {
    return cardinality[ROW];
  }

  @Override
  public int[] size() {
    return cardinality;
  }

  @Override
  public Iterator<MatrixSlice> iterator() {
    return iterateAll();
  }

  @Override
  public Iterator<MatrixSlice> iterateAll() {
    return new Iterator<MatrixSlice>() {
      private int slice;

      @Override
      public boolean hasNext() {
        return slice < numSlices();
      }

      @Override
      public MatrixSlice next() {
        if (slice >= numSlices()) {
          throw new NoSuchElementException();
        }
        int i = slice++;
        return new MatrixSlice(slice(i), i);
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException("remove() not supported for Matrix iterator");
      }
    };
  }

  /**
   * Abstracted out for iterating over either rows or columns (default is rows).
   * @param index the row or column number to grab as a vector (shallowly)
   * @return the row or column vector at that index.
   */
  protected Vector slice(int index) {
    return getRow(index);
  }

  /**
   * Abstracted out for the iterator
   * @return numRows() for row-based iterator, numColumns() for column-based.
   */
  @Override
  public int numSlices() {
    return numRows();
  }

  @Override
  public double get(String rowLabel, String columnLabel) {
    if (columnLabelBindings == null || rowLabelBindings == null) {
      throw new UnboundLabelException();
    }
    Integer row = rowLabelBindings.get(rowLabel);
    Integer col = columnLabelBindings.get(columnLabel);
    if (row == null || col == null) {
      throw new UnboundLabelException();
    }

    return get(row, col);
  }

  @Override
  public Map<String, Integer> getColumnLabelBindings() {
    return columnLabelBindings;
  }

  @Override
  public Map<String, Integer> getRowLabelBindings() {
    return rowLabelBindings;
  }

  @Override
  public void set(String rowLabel, double[] rowData) {
    if (columnLabelBindings == null) {
      throw new UnboundLabelException();
    }
    Integer row = rowLabelBindings.get(rowLabel);
    if (row == null) {
      throw new UnboundLabelException();
    }
    set(row, rowData);
  }

  @Override
  public void set(String rowLabel, int row, double[] rowData) {
    if (rowLabelBindings == null) {
      rowLabelBindings = new HashMap<String, Integer>();
    }
    rowLabelBindings.put(rowLabel, row);
    set(row, rowData);
  }

  @Override
  public void set(String rowLabel, String columnLabel, double value) {
    if (columnLabelBindings == null || rowLabelBindings == null) {
      throw new UnboundLabelException();
    }
    Integer row = rowLabelBindings.get(rowLabel);
    Integer col = columnLabelBindings.get(columnLabel);
    if (row == null || col == null) {
      throw new UnboundLabelException();
    }
    set(row, col, value);
  }

  @Override
  public void set(String rowLabel, String columnLabel, int row, int column, double value) {
    if (rowLabelBindings == null) {
      rowLabelBindings = new HashMap<String, Integer>();
    }
    rowLabelBindings.put(rowLabel, row);
    if (columnLabelBindings == null) {
      columnLabelBindings = new HashMap<String, Integer>();
    }
    columnLabelBindings.put(columnLabel, column);

    set(row, column, value);
  }

  @Override
  public void setColumnLabelBindings(Map<String, Integer> bindings) {
    columnLabelBindings = bindings;
  }

  @Override
  public void setRowLabelBindings(Map<String, Integer> bindings) {
    rowLabelBindings = bindings;
  }

  // index into int[2] for column value
  public static final int COL = 1;

  // index into int[2] for row value
  public static final int ROW = 0;

  @Override
  public int numRows() {
    return size()[ROW];
  }

  @Override
  public int numCols() {
    return size()[COL];
  }

  @Override
  public String asFormatString() {
    return toString();
  }

  @Override
  public Matrix assign(double value) {
    int[] c = size();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        setQuick(row, col, value);
      }
    }
    return this;
  }

  @Override
  public Matrix assign(double[][] values) {
    int[] c = size();
    if (c[ROW] != values.length) {
      throw new CardinalityException(c[ROW], values.length);
    }
    for (int row = 0; row < c[ROW]; row++) {
      if (c[COL] == values[row].length) {
        for (int col = 0; col < c[COL]; col++) {
          setQuick(row, col, values[row][col]);
        }
      } else {
        throw new CardinalityException(c[COL], values[row].length);
      }
    }
    return this;
  }

  @Override
  public Matrix assign(Matrix other, DoubleDoubleFunction function) {
    int[] c = size();
    int[] o = other.size();
    if (c[ROW] != o[ROW]) {
      throw new CardinalityException(c[ROW], o[ROW]);
    }
    if (c[COL] != o[COL]) {
      throw new CardinalityException(c[COL], o[COL]);
    }
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(
            row, col)));
      }
    }
    return this;
  }

  @Override
  public Matrix assign(Matrix other) {
    int[] c = size();
    int[] o = other.size();
    if (c[ROW] != o[ROW]) {
      throw new CardinalityException(c[ROW], o[ROW]);
    }
    if (c[COL] != o[COL]) {
      throw new CardinalityException(c[COL], o[COL]);
    }
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        setQuick(row, col, other.getQuick(row, col));
      }
    }
    return this;
  }

  @Override
  public Matrix assign(DoubleFunction function) {
    int[] c = size();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        setQuick(row, col, function.apply(getQuick(row, col)));
      }
    }
    return this;
  }

  /**
   * Collects the results of a function applied to each row of a matrix.
   *
   * @param f The function to be applied to each row.
   * @return The vector of results.
   */
  @Override
  public Vector aggregateRows(VectorFunction f) {
    Vector r = new DenseVector(numRows());
    int n = numRows();
    for (int row = 0; row < n; row++) {
      r.set(row, f.apply(viewRow(row)));
    }
    return r;
  }

  /**
   * Returns a view of a row.  Changes to the view will affect the original.
   * @param row  Which row to return.
   * @return A vector that references the desired row.
   */
  @Override
  public Vector viewRow(int row) {
    return new MatrixVectorView(this, row, 0, 0, 1);
  }


  /**
   * Returns a view of a row.  Changes to the view will affect the original.
   * @param column Which column to return.
   * @return A vector that references the desired column.
   */
  @Override
  public Vector viewColumn(int column) {
    return new MatrixVectorView(this, 0, column, 1, 0);
  }

  /**
   * Collects the results of a function applied to each column of a matrix.
   *
   * @param f The function to be applied to each column.
   * @return The vector of results.
   */
  @Override
  public Vector aggregateColumns(VectorFunction f) {
    Vector r = new DenseVector(numCols());
    for (int col = 0; col < numCols(); col++) {
      r.set(col, f.apply(viewColumn(col)));
    }
    return r;
  }

  /**
   * Collects the results of a function applied to each element of a matrix and then aggregated.
   *
   * @param combiner A function that combines the results of the mapper.
   * @param mapper   A function to apply to each element.
   * @return The result.
   */
  @Override
  public double aggregate(final DoubleDoubleFunction combiner, final DoubleFunction mapper) {
    return aggregateRows(new VectorFunction() {
      @Override
      public double apply(Vector v) {
        return v.aggregate(combiner, mapper);
      }
    }).aggregate(combiner, Functions.IDENTITY);
  }

  @Override
  public double determinant() {
    int[] card = size();
    int rowSize = card[ROW];
    int columnSize = card[COL];
    if (rowSize != columnSize) {
      throw new CardinalityException(rowSize, columnSize);
    }

    if (rowSize == 2) {
      return getQuick(0, 0) * getQuick(1, 1) - getQuick(0, 1) * getQuick(1, 0);
    } else {
      int sign = 1;
      double ret = 0;

      for (int i = 0; i < columnSize; i++) {
        Matrix minor = new DenseMatrix(rowSize - 1, columnSize - 1);
        for (int j = 1; j < rowSize; j++) {
          boolean flag = false; /* column offset flag */
          for (int k = 0; k < columnSize; k++) {
            if (k == i) {
              flag = true;
              continue;
            }
            minor.set(j - 1, flag ? k - 1 : k, getQuick(j, k));
          }
        }
        ret += getQuick(0, i) * sign * minor.determinant();
        sign *= -1;

      }

      return ret;
    }

  }

  @Override
  public Matrix clone() {
    AbstractMatrix clone;
    try {
      clone = (AbstractMatrix) super.clone();
    } catch (CloneNotSupportedException cnse) {
      throw new IllegalStateException(cnse); // can't happen
    }
    if (rowLabelBindings != null) {
      clone.rowLabelBindings = Maps.newHashMap(rowLabelBindings);
    }
    if (columnLabelBindings != null) {
      clone.columnLabelBindings = Maps.newHashMap(columnLabelBindings);
    }
    return clone;
  }

  @Override
  public Matrix divide(double x) {
    Matrix result = like();
    int[] c = size();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        result.setQuick(row, col, getQuick(row, col) / x);
      }
    }
    return result;
  }

  @Override
  public double get(int row, int column) {
    int[] c = size();
    if (row < 0 || row >= c[ROW]) {
      throw new IndexException(row, c[ROW]);
    }
    if (column < 0 || column >= c[COL]) {
      throw new IndexException(column, c[COL]);
    }
    return getQuick(row, column);
  }

  @Override
  public Matrix minus(Matrix other) {
    int[] c = size();
    int[] o = other.size();
    if (c[ROW] != o[ROW]) {
      throw new CardinalityException(c[ROW], o[ROW]);
    }
    if (c[COL] != o[COL]) {
      throw new CardinalityException(c[COL], o[COL]);
    }
    Matrix result = like();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        result.setQuick(row, col, getQuick(row, col)
            - other.getQuick(row, col));
      }
    }
    return result;
  }

  @Override
  public Matrix plus(double x) {
    Matrix result = like();
    int[] c = size();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        result.setQuick(row, col, getQuick(row, col) + x);
      }
    }
    return result;
  }

  @Override
  public Matrix plus(Matrix other) {
    int[] c = size();
    int[] o = other.size();
    if (c[ROW] != o[ROW]) {
      throw new CardinalityException(c[ROW], o[ROW]);
    }
    if (c[COL] != o[COL]) {
      throw new CardinalityException(c[COL], o[COL]);
    }
    Matrix result = like();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        result.setQuick(row, col, getQuick(row, col)
            + other.getQuick(row, col));
      }
    }
    return result;
  }

  @Override
  public void set(int row, int column, double value) {
    int[] c = size();
    if (row < 0 || row >= c[ROW]) {
      throw new IndexException(row, c[ROW]);
    }
    if (column < 0 || column >= c[COL]) {
      throw new IndexException(column, c[COL]);
    }
    setQuick(row, column, value);
  }

  @Override
  public void set(int row, double[] data) {
    int[] c = size();
    if (c[COL] < data.length) {
      throw new CardinalityException(c[COL], data.length);
    }
    if (row < 0 || row >= c[ROW]) {
      throw new IndexException(row, c[ROW]);
    }

    for (int i = 0; i < c[COL]; i++) {
      setQuick(row, i, data[i]);
    }
  }

  @Override
  public Matrix times(double x) {
    Matrix result = like();
    int[] c = size();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        result.setQuick(row, col, getQuick(row, col) * x);
      }
    }
    return result;
  }

  @Override
  public Matrix times(Matrix other) {
    int[] c = size();
    int[] o = other.size();
    if (c[COL] != o[ROW]) {
      throw new CardinalityException(c[COL], o[ROW]);
    }
    Matrix result = like(c[ROW], o[COL]);
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < o[COL]; col++) {
        double sum = 0;
        for (int k = 0; k < c[COL]; k++) {
          sum += getQuick(row, k) * other.getQuick(k, col);
        }
        result.setQuick(row, col, sum);
      }
    }
    return result;
  }

  @Override
  public Vector times(Vector v) {
    int[] c = size();
    if (c[COL] != v.size()) {
      throw new CardinalityException(c[COL], v.size());
    }
    Vector w = new DenseVector(c[ROW]);
    for (int i = 0; i < c[ROW]; i++) {
      w.setQuick(i, v.dot(getRow(i)));
    }
    return w;
  }

  @Override
  public Vector timesSquared(Vector v) {
    int[] c = size();
    if (c[COL] != v.size()) {
      throw new CardinalityException(c[COL], v.size());
    }
    Vector w = new DenseVector(c[COL]);
    for (int i = 0; i < c[ROW]; i++) {
      Vector xi = getRow(i);
      double d = xi.dot(v);
      if (d != 0.0) {
        w.assign(xi, new PlusMult(d));
      }

    }
    return w;
  }

  @Override
  public Matrix transpose() {
    int[] card = size();
    Matrix result = like(card[COL], card[ROW]);
    for (int row = 0; row < card[ROW]; row++) {
      for (int col = 0; col < card[COL]; col++) {
        result.setQuick(col, row, getQuick(row, col));
      }
    }
    return result;
  }

  @Override
  public Matrix viewPart(int rowOffset, int rowsRequested, int columnOffset, int columnsRequested) {
    return viewPart(new int[]{rowOffset, columnOffset}, new int[]{rowsRequested, columnsRequested});
  }

  @Override
  public double zSum() {
    double result = 0;
    int[] c = size();
    for (int row = 0; row < c[ROW]; row++) {
      for (int col = 0; col < c[COL]; col++) {
        result += getQuick(row, col);
      }
    }
    return result;
  }

  protected class TransposeViewVector extends AbstractVector {

    private final Matrix matrix;
    private final int transposeOffset;
    private final int numCols;
    private final boolean rowToColumn;

    protected TransposeViewVector(Matrix m, int offset) {
      this(m, offset, true);
    }

    protected TransposeViewVector(Matrix m, int offset, boolean rowToColumn) {
      super(rowToColumn ? m.numRows() : m.numCols());
      matrix = m;
      this.transposeOffset = offset;
      this.rowToColumn = rowToColumn;
      numCols = rowToColumn ? m.numCols() : m.numRows();
    }

    @Override
    public Vector clone() {
      Vector v = new DenseVector(size());
      addTo(v);
      return v;
    }

    @Override
    public boolean isDense() {
      return true;
    }

    @Override
    public boolean isSequentialAccess() {
      return true;
    }

    @Override
    protected Matrix matrixLike(int rows, int columns) {
      return matrix.like(rows, columns);
    }

    @Override
    public Iterator<Element> iterator() {
      return new Iterator<Element>() {
        private int i;
        @Override
        public boolean hasNext() {
          return i < size();
        }

        @Override
        public Element next() {
          if (i >= size()) {
            throw new NoSuchElementException();
          }
          return getElement(i++);
        }

        @Override
        public void remove() {
          throw new UnsupportedOperationException("Element removal not supported");
        }
      };
    }

    /**
     * Currently delegates to {@link #iterator()}.
     * TODO: This could be optimized to at least skip empty rows if there are many of them.
     * @return an iterator (currently dense).
     */
    @Override
    public Iterator<Element> iterateNonZero() {
      return iterator();
    }

    @Override
    public Element getElement(final int i) {
      return new Element() {
        @Override
        public double get() {
          return getQuick(i);
        }

        @Override
        public int index() {
          return i;
        }

        @Override
        public void set(double value) {
          setQuick(i, value);
        }
      };
    }

    @Override
    public double getQuick(int index) {
      Vector v = rowToColumn ? matrix.getRow(index) : matrix.getColumn(index);
      return v == null ? 0 : v.getQuick(transposeOffset);
    }

    @Override
    public void setQuick(int index, double value) {
      Vector v = rowToColumn ? matrix.getRow(index) : matrix.getColumn(index);
      if (v == null) {
        v = newVector(numCols);
        matrix.assignRow(index, v);
      }
      v.setQuick(transposeOffset, value);
    }

    protected Vector newVector(int cardinality) {
      return new DenseVector(cardinality);
    }

    @Override
    public Vector like() {
      return new DenseVector(size());
    }

    public Vector like(int cardinality) {
      return new DenseVector(cardinality);
    }

    /**
     * TODO: currently I don't know of an efficient way to getVector this value correctly.
     *
     * @return the number of nonzero entries
     */
    @Override
    public int getNumNondefaultElements() {
      return size();
    }
  }

}
TOP

Related Classes of org.apache.mahout.math.AbstractMatrix$TransposeViewVector

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.