cavis/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsComparisonFortran.java
Alex Black ddf70ac450
Avoid double printing of start/stop test in a few cases (#210)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-02-03 22:18:01 +11:00

257 lines
11 KiB
Java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.CheckUtil;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
/**
* Tests comparing Nd4j ops to other libraries
*/
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonFortran extends BaseNd4jTest {
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonFortran.class);
public static final int SEED = 123;
DataType initialType;
public Nd4jTestsComparisonFortran(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
@Before
public void before() throws Exception {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
Nd4j.getRandom().setSeed(SEED);
}
@After
public void after() throws Exception {
DataTypeUtil.setDTypeForContext(initialType);
}
@Override
public char ordering() {
return 'f';
}
@Test
public void testCrash() {
INDArray array3d = Nd4j.ones(1, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 0);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array3d, 1);
INDArray array4d = Nd4j.ones(1, 10, 10, 10);
Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(array4d, 0);
}
@Test
public void testMmulWithOpsCommonsMath() {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {
for (int j = 0; j < second.size(); j++) {
Pair<INDArray, String> p1 = first.get(i);
Pair<INDArray, String> p2 = second.get(j);
String errorMsg = getTestWithOpsErrorMsg(i, j, "mmul", p1, p2);
assertTrue(errorMsg, CheckUtil.checkMmul(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6));
}
}
}
@Test
public void testGemmWithOpsCommonsMath() {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> secondT = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, SEED, DataType.DOUBLE);
double[] alpha = {1.0, -0.5, 2.5};
double[] beta = {0.0, -0.25, 1.5};
INDArray cOrig = Nd4j.create(new int[] {3, 4});
Random r = new Random(12345);
for (int i = 0; i < cOrig.size(0); i++) {
for (int j = 0; j < cOrig.size(1); j++) {
cOrig.putScalar(new int[] {i, j}, r.nextDouble());
}
}
for (int i = 0; i < first.size(); i++) {
for (int j = 0; j < second.size(); j++) {
for (int k = 0; k < alpha.length; k++) {
for (int m = 0; m < beta.length; m++) {
//System.out.println((String.format("Running iteration %d %d %d %d", i, j, k, m)));
INDArray cff = Nd4j.create(cOrig.shape(), 'f');
cff.assign(cOrig);
INDArray cft = Nd4j.create(cOrig.shape(), 'f');
cft.assign(cOrig);
INDArray ctf = Nd4j.create(cOrig.shape(), 'f');
ctf.assign(cOrig);
INDArray ctt = Nd4j.create(cOrig.shape(), 'f');
ctt.assign(cOrig);
double a = alpha[k];
double b = beta[k];
Pair<INDArray, String> p1 = first.get(i);
Pair<INDArray, String> p1T = firstT.get(i);
Pair<INDArray, String> p2 = second.get(j);
Pair<INDArray, String> p2T = secondT.get(j);
String errorMsgff = getGemmErrorMsg(i, j, false, false, a, b, p1, p2);
String errorMsgft = getGemmErrorMsg(i, j, false, true, a, b, p1, p2T);
String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2);
String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T);
assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a,
b, 1e-4, 1e-6));
assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a,
b, 1e-4, 1e-6));
assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a,
b, 1e-4, 1e-6));
assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a,
b, 1e-4, 1e-6));
}
}
}
}
}
@Test
public void testGemvApacheCommons() {
int[] rowsArr = new int[] {4, 4, 4, 8, 8, 8};
int[] colsArr = new int[] {2, 1, 10, 2, 1, 10};
for (int x = 0; x < rowsArr.length; x++) {
int rows = rowsArr[x];
int cols = colsArr[x];
List<Pair<INDArray, String>> matrices = NDArrayCreationUtil.getAllTestMatricesWithShape(rows, cols, 12345, DataType.DOUBLE);
List<Pair<INDArray, String>> vectors = NDArrayCreationUtil.getAllTestMatricesWithShape(cols, 1, 12345, DataType.DOUBLE);
for (int i = 0; i < matrices.size(); i++) {
for (int j = 0; j < vectors.size(); j++) {
Pair<INDArray, String> p1 = matrices.get(i);
Pair<INDArray, String> p2 = vectors.get(j);
String errorMsg = getTestWithOpsErrorMsg(i, j, "mmul", p1, p2);
INDArray m = p1.getFirst();
INDArray v = p2.getFirst();
RealMatrix rm = new BlockRealMatrix(m.rows(), m.columns());
for (int r = 0; r < m.rows(); r++) {
for (int c = 0; c < m.columns(); c++) {
double d = m.getDouble(r, c);
rm.setEntry(r, c, d);
}
}
RealMatrix rv = new BlockRealMatrix(cols, 1);
for (int r = 0; r < v.rows(); r++) {
double d = v.getDouble(r, 0);
rv.setEntry(r, 0, d);
}
INDArray gemv = m.mmul(v);
RealMatrix gemv2 = rm.multiply(rv);
assertArrayEquals(new long[] {rows, 1}, gemv.shape());
assertArrayEquals(new int[] {rows, 1},
new int[] {gemv2.getRowDimension(), gemv2.getColumnDimension()});
//Check entries:
for (int r = 0; r < rows; r++) {
double exp = gemv2.getEntry(r, 0);
double act = gemv.getDouble(r, 0);
assertEquals(errorMsg, exp, act, 1e-5);
}
}
}
}
}
@Test
public void testAddSubtractWithOpsCommonsMath() {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {
for (int j = 0; j < second.size(); j++) {
Pair<INDArray, String> p1 = first.get(i);
Pair<INDArray, String> p2 = second.get(j);
String errorMsg1 = getTestWithOpsErrorMsg(i, j, "add", p1, p2);
String errorMsg2 = getTestWithOpsErrorMsg(i, j, "sub", p1, p2);
boolean addFail = CheckUtil.checkAdd(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6);
assertTrue(errorMsg1, addFail);
boolean subFail = CheckUtil.checkSubtract(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6);
assertTrue(errorMsg2, subFail);
}
}
}
@Test
public void testMulDivOnCheckUtilMatrices() {
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
for (int i = 0; i < first.size(); i++) {
for (int j = 0; j < second.size(); j++) {
Pair<INDArray, String> p1 = first.get(i);
Pair<INDArray, String> p2 = second.get(j);
String errorMsg1 = getTestWithOpsErrorMsg(i, j, "mul", p1, p2);
String errorMsg2 = getTestWithOpsErrorMsg(i, j, "div", p1, p2);
assertTrue(errorMsg1, CheckUtil.checkMulManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6));
assertTrue(errorMsg2, CheckUtil.checkDivManually(p1.getFirst(), p2.getFirst(), 1e-4, 1e-6));
}
}
}
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
Pair<INDArray, String> second) {
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
}
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
return i + "," + j + " - gemm(tA=" + transposeA + ",tB= " + transposeB + ",alpha=" + alpha + ",beta= " + beta
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
}
}