Package mikera.arrayz.impl

Source Code of mikera.arrayz.impl.SliceArray

package mikera.arrayz.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import mikera.arrayz.Array;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;

/**
* A general n-dimensional double array implemented as a collection of (n-1) dimensional slices
*
* @author Mike
*
* @param <T>
*/
public final class SliceArray<T extends INDArray> extends BaseShapedArray {
  private static final long serialVersionUID = -2343678749417219155L;

  private final long[] longShape;
  private final T[] slices;
 
  private SliceArray(int[] shape, T[] slices) {
    super(shape);
    this.slices=slices;
    int dims=shape.length;
    this.longShape=new long[dims];
    for (int i=0; i<dims; i++) {
      longShape[i]=shape[i];
    }
  }
 
  @SuppressWarnings("unchecked")
  public static <T extends INDArray>  SliceArray<T> create(INDArray a) {
    return new SliceArray<T>(a.getShape(),(T[]) a.toSliceArray());
  }
 
  public static <T extends INDArray>  SliceArray<T> of(T... slices) {
    return new SliceArray<T>(IntArrays.consArray(slices.length,slices[0].getShape()),slices.clone());
  }
 
  public static SliceArray<INDArray> repeat (INDArray s, int n) {
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    for (int i=0; i<n; i++) {
      al.add(s);
    }
    return SliceArray.create(al);
  }
 
  public static SliceArray<INDArray> create(List<INDArray> slices) {
    int slen=slices.size();
    if (slen==0) throw new VectorzException("Empty list of slices provided to SliceArray");
    INDArray[] arr=new INDArray[slen];
    return new SliceArray<INDArray>(IntArrays.consArray(slen,slices.get(0).getShape()),slices.toArray(arr));
  }
 
  public static SliceArray<INDArray> create(List<INDArray> slices, int[] shape) {
    int slen=slices.size();
    INDArray[] arr=new INDArray[slen];
    return new SliceArray<INDArray>(shape,slices.toArray(arr));
  }
 
  @Override
  public int dimensionality() {
    return shape.length;
  }
 
  @Override
  public long[] getLongShape() {
     return longShape;
  }

  @Override
  public double get(int... indexes) {
    int d=indexes.length;
    T slice=slices[indexes[0]];
    switch (d) {
      case 0: throw new VectorzException("Can't do 0D get on SliceArray!");
      case 1: return slice.get();
      case 2: return slice.get(indexes[1]);
      case 3: return slice.get(indexes[1],indexes[2]);
      default: return slice.get(Arrays.copyOfRange(indexes,1,d));
    }
  }
 
  @Override
  public void set(double value) {
    for (T s:slices) {
      s.set(value);
    }
  }
 
  public void fill(double value) {
    for (T s:slices) {
      s.fill(value);
    }
  }
 
  @Override
  public void set(int[] indexes, double value) {
    int d=indexes.length;
    if (d==0) {
      set(value);
    }
   
    T slice=slices[indexes[0]];
    switch (d) {
      case 0: throw new VectorzException("Can't do 0D set on SliceArray!");
      case 1: slice.set(value); return;
      case 2: slice.set(indexes[1],value); return;
      case 3: slice.set(indexes[1],indexes[2],value); return;
      default: slice.set(Arrays.copyOfRange(indexes,1,d),value); return;
    }
  }

  @Override
  public AVector asVector() {
    AVector v=Vector0.INSTANCE;
    for (INDArray a:slices) {
      v=v.join(a.asVector());
    }
    return v; 
  }

  @Override
  public INDArray reshape(int... dimensions) {
    return Arrayz.createFromVector(asVector(), dimensions);
  }

  @Override
  public INDArray slice(int majorSlice) {
    return slices[majorSlice];
  }
 

  @Override
  public INDArray slice(int dimension, int index) {
    if (dimension<0) throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
    if (dimension==0) return slice(index);
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    for (INDArray s:this) {
      al.add(s.slice(dimension-1,index));
    }
    return SliceArray.create(al)
 

  @Override
  public long elementCount() {
    return IntArrays.arrayProduct(shape);
  }
 
  public INDArray innerProduct(INDArray a) {
    int dims=dimensionality();
    switch (dims) {
      case 0: {
        a=a.clone();
        a.scale(get());
        return a;
      }
      case 1: {
        return this.toVector().innerProduct(a);
      }
      case 2: {
        return Matrix.create(this).innerProduct(a);
      }
    }
   
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    for (INDArray s:this) {
      al.add(s.innerProduct(a));
    }
    return Arrayz.create(al);
  }
 
  public INDArray outerProduct(INDArray a) {
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    for (INDArray s:this) {
      al.add(s.outerProduct(a));
    }
    return Arrayz.create(al);
  }

  @Override
  public boolean isMutable() {
    for (INDArray a:slices) {
      if (a.isMutable()) return true;
    }
    return false;
  }

  @Override
  public boolean isFullyMutable() {
    for (INDArray a:slices) {
      if (!a.isFullyMutable()) return false;
    }
    return true;
  }
 
  @Override
  public boolean isZero() {
    for (INDArray a:slices) {
      if (!a.isZero()) return false;
    }
    return true;
  }

 
  @Override
  public boolean isBoolean() {
    for (INDArray a:slices) {
      if (!a.isBoolean()) return false;
    }
    return true;
  }

  @Override
  public boolean isElementConstrained() {
    for (INDArray a:slices) {
      if (a.isElementConstrained()) return true;
    }
    return false;
  }

  @Override
  public boolean isView() {
    return true;
  }

  @Override
  public void applyOp(Op op) {
    for (INDArray a:slices) {
      a.applyOp(op);
    }
  }

  @Override
  public void applyOp(IOperator op) {
    for (INDArray a:slices) {
      a.applyOp(op);
    }
  }
 
  @Override
  public void multiply(double d) {
    for (INDArray a:slices) {
      a.scale(d);
    }
  }

  @Override
  public boolean equals(INDArray a) {
    if (!Arrays.equals(a.getShape(), this.getShape())) return false;
    for (int i=0; i<slices.length; i++) {
      if (!slices[i].equals(a.slice(i))) return false;
    }
    return true;
  }

  @SuppressWarnings("unchecked")
  @Override
  public SliceArray<T> exactClone() {
    T[] newSlices=slices.clone();
    for (int i=0; i<slices.length; i++) {
      newSlices[i]=(T) newSlices[i].exactClone();
    }
    return new SliceArray<T>(shape,newSlices);
  }

  @Override
  public List<?> getSlices() {
    ArrayList<Object> al=new ArrayList<Object>();
    if (dimensionality()==1) {
      int n=sliceCount();
      for (int i=0; i<n ; i++) {
        al.add(Double.valueOf(slices[i].get()));
      }
      return al;
    }
    for (INDArray sl:this) {
      al.add(sl);
    }
    return al;
  }
 
  @Override
  public INDArray[] toSliceArray() {
    int n=sliceCount();
    INDArray[] al=new INDArray[n];
    for (int i=0; i<n; i++) {
      al[i]=slice(i);
    }
    return al;
  }
 
  @Override
  public double[] toDoubleArray() {
    double[] result=Array.createStorage(this.getShape());
    int skip=(int)slice(0).elementCount();
    for (int i=0; i<slices.length; i++) {
      INDArray s=slices[i];
      if (s.isSparse()) {
        s.addToArray(result,skip*i);       
      } else {
        s.getElements(result,skip*i);
      }
    }
    return result;
  }
 
  @Override
  public boolean equalsArray(double[] values, int offset) {
    int skip=(int)slice(0).elementCount();
    int di=offset;
    for (int i=0; i<slices.length; i++) {
      if (!slices[i].equalsArray(values,di)) return false;
      di+=skip;
    }
    return true;
  }

  @Override
  public void validate() {
    if (shape.length!=longShape.length) throw new VectorzException("Shape mismatch");
   
    long ec=0;
    for (int i=0; i<slices.length; i++) {
      T s=slices[i];
      ec+=s.elementCount();
      slices[i].validate();
      int[] ss=s.getShape();
      for (int j=0; j<ss.length; j++) {
        if(getShape(j+1)!=ss[j]) throw new VectorzException("Slice shape mismatch");
      }
    }
    if (ec!=elementCount()) throw new VectorzException("Element count mismatch");
   
    super.validate();
  }

  @Override
  public double get() {
    throw new IllegalArgumentException("0d get not supported on "+getClass());
  }

  @Override
  public double get(int x) {
    return slices[x].get();
  }

  @Override
  public double get(int x, int y) {
    return slices[x].get(y);
  }
}
TOP

Related Classes of mikera.arrayz.impl.SliceArray

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.