/*
* Copyright (c) 2009-2012, Peter Abeles. All Rights Reserved.
*
* This file is part of Efficient Java Matrix Library (EJML).
*
* EJML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* EJML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with EJML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.ejml.alg.block;
import org.ejml.data.BlockMatrix64F;
import org.ejml.data.D1Submatrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.simple.SimpleMatrix;
import org.junit.Test;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
/**
 * Tests the submatrix (non-"Block") multiplication variants in {@link BlockMultiplication}
 * by comparing their output against a reference computed with {@link SimpleMatrix}.
 *
 * @author Peter Abeles
 */
public class TestBlockMultiplication {

    // fixed seed so failures are reproducible
    private static final Random rand = new Random(234234);

    private static final int BLOCK_LENGTH = 4;
    private static final int numRows = 10;
    private static final int numCols = 13;

    /**
     * Checks to see if matrix multiplication variants handle submatrices correctly.
     *
     * <p>Every declared method of {@link BlockMultiplication} whose name contains "mult" but not
     * "Block" is exercised. Whether A and/or B are transposed is inferred from "TransA"/"TransB"
     * in the name, and the operation from "Plus"/"Minus" (anything else is treated as a set).</p>
     */
    @Test
    public void mult_submatrix() {
        Method[] methods = BlockMultiplication.class.getDeclaredMethods();

        int numFound = 0;
        for( Method m : methods ) {
            String name = m.getName();

            // only the submatrix variants are tested here
            if( name.contains("Block") || !name.contains("mult") )
                continue;

            boolean transA = name.contains("TransA");
            boolean transB = name.contains("TransB");

            // 1 : C = C + op(A)*op(B)
            // -1 : C = C - op(A)*op(B)
            // 0 : C = op(A)*op(B)   (the "Set" variants and plain "mult")
            int operationType = 0;
            if( name.contains("Plus") )
                operationType = 1;
            else if( name.contains("Minus") )
                operationType = -1;

            checkMult_submatrix(m, operationType, transA, transB);
            numFound++;
        }

        // make sure all the functions were in fact tested
        assertEquals(7, numFound);
    }

    /**
     * Runs a single multiplication function against several submatrix layouts.
     *
     * @param func function under test
     * @param operationType 1 for plus, -1 for minus, 0 for set
     * @param transA true if the function transposes A
     * @param transB true if the function transposes B
     */
    private static void checkMult_submatrix( Method func , int operationType , boolean transA , boolean transB )
    {
        // the submatrix is the same size as the originals
        checkMult_submatrix(func, operationType, transA, transB,
                sub(0, 0, numRows, numCols),
                sub(0, 0, numCols, numRows));

        // submatrix has a size in multiples of the block
        checkMult_submatrix(func, operationType, transA, transB,
                sub(BLOCK_LENGTH, BLOCK_LENGTH, BLOCK_LENGTH*2, BLOCK_LENGTH*2),
                sub(BLOCK_LENGTH, BLOCK_LENGTH, BLOCK_LENGTH*2, BLOCK_LENGTH*2));

        // submatrix row and column ends at a fraction of a block
        checkMult_submatrix(func, operationType, transA, transB,
                sub(BLOCK_LENGTH, BLOCK_LENGTH, numRows, numCols),
                sub(BLOCK_LENGTH, BLOCK_LENGTH, numCols, numRows));

        // the previous tests have some symmetry in them which can mask errors
        checkMult_submatrix(func, operationType, transA, transB,
                sub(0, BLOCK_LENGTH, BLOCK_LENGTH, 2*BLOCK_LENGTH),
                sub(0, BLOCK_LENGTH, BLOCK_LENGTH, numRows));
    }

    /**
     * Multiplies the two sub-matrices together. Checks to see if the same result
     * is found when multiplied using the normal algorithm versus the submatrix one.
     *
     * <p>A and B are given in their non-transposed layout; if the function under test expects a
     * transposed input, the original matrix and the submatrix bounds are transposed before the
     * call. Note that A and B are modified by this method.</p>
     *
     * @throws IllegalArgumentException if A or B does not start on a block boundary
     */
    private static void checkMult_submatrix( Method func , int operationType , boolean transA , boolean transB ,
                                             D1Submatrix64F A , D1Submatrix64F B ) {
        if( A.col0 % BLOCK_LENGTH != 0 || A.row0 % BLOCK_LENGTH != 0 )
            throw new IllegalArgumentException("Submatrix A is not block aligned");
        if( B.col0 % BLOCK_LENGTH != 0 || B.row0 % BLOCK_LENGTH != 0 )
            throw new IllegalArgumentException("Submatrix B is not block aligned");

        BlockMatrix64F origA = BlockMatrixOps.createRandom(numRows, numCols, -1, 1, rand, BLOCK_LENGTH);
        BlockMatrix64F origB = BlockMatrixOps.createRandom(numCols, numRows, -1, 1, rand, BLOCK_LENGTH);

        A.original = origA;
        B.original = origB;

        int w = B.col1 - B.col0;
        int h = A.row1 - A.row0;

        // C is offset by one block inside a larger matrix to make the test harder: the function
        // must respect the submatrix bounds. Its initial contents are random so that a "set"
        // can be told apart from a "plus"/"minus".
        BlockMatrix64F subC = BlockMatrixOps.createRandom(BLOCK_LENGTH + h, BLOCK_LENGTH + w, -1, 1, rand, BLOCK_LENGTH);
        D1Submatrix64F C = new D1Submatrix64F(subC, BLOCK_LENGTH, subC.numRows, BLOCK_LENGTH, subC.numCols);

        // reference result, computed before the function under test overwrites C
        DenseMatrix64F rmC = multByExtract(operationType, A, B, C);

        if( transA ) {
            origA = BlockMatrixOps.transpose(origA, null);
            transposeSub(A);
            A.original = origA;
        }

        if( transB ) {
            origB = BlockMatrixOps.transpose(origB, null);
            transposeSub(B);
            B.original = origB;
        }

        invokeMult(func, A, B, C);

        for( int i = C.row0; i < C.row1; i++ ) {
            for( int j = C.col0; j < C.col1; j++ ) {
                double found = subC.get(i, j);
                double expected = rmC.get(i - C.row0, j - C.col0);
                double diff = Math.abs(found - expected);

                // written as !(diff < tol) rather than diff >= tol so that NaN is a failure
                if( !(diff < 1e-12) ) {
                    subC.print();
                    rmC.print();
                    System.out.println(func.getName());
                    System.out.println("transA "+transA);
                    System.out.println("transB "+transB);
                    System.out.println("type "+operationType);
                    System.out.println("element ("+i+","+j+") expected "+expected+" found "+found);
                    fail("Error too large");
                }
            }
        }
    }

    /**
     * Calls the function under test as {@code func(BLOCK_LENGTH, A, B, C)} with a null target,
     * so it is expected to be static. Unchecked exceptions and errors thrown by the function are
     * rethrown as-is so the real failure is reported instead of a reflection wrapper.
     */
    private static void invokeMult( Method func , D1Submatrix64F A , D1Submatrix64F B , D1Submatrix64F C ) {
        try {
            func.invoke(null, BLOCK_LENGTH, A, B, C);
        } catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            Throwable cause = e.getCause();
            if( cause instanceof RuntimeException )
                throw (RuntimeException)cause;
            if( cause instanceof Error )
                throw (Error)cause;
            throw new RuntimeException(cause);
        }
    }

    /**
     * Swaps the row and column bounds of the submatrix in place, so that it describes the
     * same region of the transposed original matrix.
     */
    public static void transposeSub(D1Submatrix64F A) {
        int temp = A.col0;
        A.col0 = A.row0;
        A.row0 = temp;

        temp = A.col1;
        A.col1 = A.row1;
        A.row1 = temp;
    }

    /**
     * Creates a submatrix with no original matrix attached yet. Note that the argument order here
     * is (row0, col0, row1, col1), whereas the {@link D1Submatrix64F} constructor takes
     * (original, row0, row1, col0, col1).
     */
    private static D1Submatrix64F sub( int row0 , int col0 , int row1 , int col1 ) {
        return new D1Submatrix64F(null, row0, row1, col0, col1);
    }

    /**
     * Computes the reference result by extracting the submatrices and multiplying them with
     * {@link SimpleMatrix}.
     *
     * @param operationType 1 returns A*B + C, -1 returns C - A*B, 0 returns A*B
     * @return the expected contents of the C submatrix
     */
    private static DenseMatrix64F multByExtract( int operationType ,
                                                 D1Submatrix64F subA , D1Submatrix64F subB ,
                                                 D1Submatrix64F subC )
    {
        SimpleMatrix A = subA.extract();
        SimpleMatrix B = subB.extract();
        SimpleMatrix C = subC.extract();

        if( operationType > 0 )
            return A.mult(B).plus(C).getMatrix();
        else if( operationType < 0 )
            return C.minus(A.mult(B)).getMatrix();
        else
            return A.mult(B).getMatrix();
    }
}