package mikera.matrixx.impl;
import mikera.arrayz.ISparse;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.ASizedVector;
import mikera.vectorz.util.VectorzException;
/**
* Abstract base class for banded matrices
*
* Banded matrix implementations are assumed to store their data efficiently in diagonal bands,
* so functions on banded matrices can be designed to exploit this fact.
*
* May be either square or rectangular
*
* @author Mike
*
*/
public abstract class ABandedMatrix extends AMatrix implements ISparse, IFastBands {
private static final long serialVersionUID = -229314208418131186L;
@Override
public abstract int upperBandwidthLimit();
@Override
public abstract int lowerBandwidthLimit();
@Override
public abstract AVector getBand(int band);
@Override
public int upperBandwidth() {
for (int i=upperBandwidthLimit(); i>0; i--) {
if (!(getBand(i).isZero())) return i;
}
return 0;
}
@Override
public int lowerBandwidth() {
for (int i=-lowerBandwidthLimit(); i<0; i++) {
if (!(getBand(i).isZero())) return -i;
}
return 0;
}
@Override
public boolean isMutable() {
int lb=lowerBandwidthLimit(), ub=upperBandwidthLimit();
for (int i=-lb; i<=ub; i++) {
if (getBand(i).isMutable()) return true;
}
return false;
}
@Override
public boolean isFullyMutable() {
return false;
}
@Override
public boolean isSymmetric() {
if (rowCount()!=columnCount()) return false;
int bs=Math.max(upperBandwidthLimit(), lowerBandwidthLimit());
for (int i=1; i<=bs; i++) {
if (!getBand(i).equals(getBand(-i))) return false;
}
return true;
}
@Override
public boolean isUpperTriangular() {
return (lowerBandwidthLimit()==0)||(lowerBandwidth()==0);
}
@Override
public boolean isLowerTriangular() {
return (upperBandwidthLimit()==0)||(upperBandwidth()==0);
}
@Override
public AVector getRow(int row) {
return new BandedMatrixRow(row);
}
@Override
public long nonZeroCount() {
long t=0;
for (int i=-lowerBandwidthLimit(); i<=upperBandwidthLimit(); i++) {
t+=getBand(i).nonZeroCount();
}
return t;
}
@Override
public double elementSum() {
double t=0;
for (int i=-lowerBandwidthLimit(); i<=upperBandwidthLimit(); i++) {
t+=getBand(i).elementSum();
}
return t;
}
@Override
public double trace() {
return getBand(0).elementSum();
}
@Override
public double diagonalProduct() {
return getBand(0).elementProduct();
}
@Override
public double elementSquaredSum() {
double t=0;
for (int i=-lowerBandwidthLimit(); i<=upperBandwidthLimit(); i++) {
t+=getBand(i).elementSquaredSum();
}
return t;
}
@Override
public void fill(double value) {
for (int i=-rowCount()+1; i<columnCount(); i++) {
getBand(i).fill(value);
}
}
@Override
public Matrix toMatrix() {
int rc = rowCount();
int cc = columnCount();
Matrix m = Matrix.create(rc, cc);
for (int i=-lowerBandwidthLimit(); i<=upperBandwidthLimit(); i++) {
m.getBand(i).set(this.getBand(i));
}
return m;
}
@Override
public Matrix toMatrixTranspose() {
int rc = rowCount();
int cc = columnCount();
Matrix m = Matrix.create(cc, rc);
for (int i=-lowerBandwidthLimit(); i<=upperBandwidthLimit(); i++) {
m.getBand(-i).set(this.getBand(i));
}
return m;
}
/**
* Inner class for generic banded matrix rows
* @author Mike
*
*/
@SuppressWarnings("serial")
private final class BandedMatrixRow extends ASizedVector {
final int row;
final int lower;
final int upper;
public BandedMatrixRow(int row) {
super(columnCount());
this.row=row;
this.lower=-lowerBandwidthLimit();
this.upper=upperBandwidthLimit();
}
@Override
public double get(int i) {
checkIndex(i);
return unsafeGet(i);
}
@Override
public double unsafeGet(int i) {
int b=i-row;
if ((b<lower)||(b>upper)) return 0;
return getBand(b).unsafeGet(Math.min(i, row));
}
@Override
public double dotProduct(AVector v) {
double result=0.0;
for (int i=Math.max(0,lower+row); i<=Math.min(length-1, row+upper);i++) {
result+=getBand(i-row).unsafeGet(Math.min(i, row))*v.unsafeGet(i);
}
return result;
}
@Override
public double dotProduct(Vector v) {
double result=0.0;
for (int i=Math.max(0,lower+row); i<=Math.min(length-1, row+upper);i++) {
result+=getBand(i-row).unsafeGet(Math.min(i, row))*v.unsafeGet(i);
}
return result;
}
@Override
public void set(int i, double value) {
checkIndex(i);
unsafeSet(i,value);
}
@Override
public void unsafeSet(int i, double value) {
int b=i-row;
getBand(b).unsafeSet(Math.min(i, row),value);
}
@Override
public AVector exactClone() {
return ABandedMatrix.this.exactClone().getRow(row);
}
@Override
public boolean isFullyMutable() {
return ABandedMatrix.this.isFullyMutable();
}
@Override
public double dotProduct(double[] data, int offset) {
double result=0.0;
for (int i=0; i<length; i++) {
result+=data[offset+i]*unsafeGet(i);
}
return result;
}
}
@Override
public void addToArray(double[] data, int offset) {
int b1=-lowerBandwidth();
int b2=upperBandwidth();
int cc=columnCount();
for (int b=b1; b<=b2; b++) {
AVector band=getBand(b);
int di = offset+this.bandStartColumn(b)+cc*bandStartRow(b);
band.addToArray(data, di, cc+1);
}
}
@Override
public double[] toDoubleArray() {
double[] result=Matrix.createStorage(rowCount(),columnCount());
// since this array is sparse, fastest to use addToArray to modify only non-zero elements
addToArray(result,0);
return result;
}
@Override public void validate() {
super.validate();
if (lowerBandwidthLimit()<0) throw new VectorzException("Negative lower bandwidth limit?!?");
int minBand=-lowerBandwidthLimit();
int maxBand=upperBandwidthLimit();
if (minBand<=-rowCount()) throw new VectorzException("Invalid lower limit: "+minBand);
if (maxBand>=columnCount()) throw new VectorzException("Invalid upper limit: "+maxBand);
for (int i=minBand; i<=maxBand; i++) {
AVector v=getBand(i);
if (bandLength(i)!=v.length()) throw new VectorzException("Invalid band length: "+i);
}
}
@Override
public double density() {
return nonZeroCount()/((double)elementCount());
}
}