154 lines
6.9 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.checkutil.CheckUtil;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import static org.junit.Assert.assertTrue;
/**
 * Tests comparing Nd4j ops against reference results computed by {@link CheckUtil}
 * (presumably backed by another library such as Commons Math — confirm in CheckUtil).
 *
 * <p>This variant reports {@code 'c'} from {@link #ordering()}. Each test runs with the
 * context data type set to {@link DataType#DOUBLE}. The data type that was active when the
 * test instance was constructed is restored in {@link #after()}.
 */
@RunWith(Parameterized.class)
public class Nd4jTestsComparisonC extends BaseNd4jTest {
    private static final Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);

    public static final int SEED = 123;

    // Tolerances passed to CheckUtil.checkGemm as its last two arguments
    // (assumed: max relative difference, then min absolute difference — confirm in CheckUtil).
    private static final double GEMM_MAX_REL_DIFF = 1e-4;
    private static final double GEMM_MIN_ABS_DIFF = 1e-6;

    /** Data type in effect at construction time; restored after each test. */
    final DataType initialType;

    public Nd4jTestsComparisonC(Nd4jBackend backend) {
        super(backend);
        this.initialType = Nd4j.dataType();
    }

    /** Runs the base-class setup, then forces double precision for the test body. */
    @Before
    @Override
    public void before() throws Exception {
        super.before();
        DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
    }

    /** Runs the base-class teardown, then restores the data type captured at construction. */
    @After
    @Override
    public void after() throws Exception {
        super.after();
        DataTypeUtil.setDTypeForContext(initialType);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Checks gemm (C = alpha * op(A) * op(B) + beta * C) for every pair of test matrices, every
     * transpose combination, and every (alpha, beta) pair in the cross product of the two arrays.
     *
     * <p>A is drawn from the 3x5 test matrices (or 5x3 when transposed), B from 5x4 (or 4x5 when
     * transposed). C starts as a 3x4 'f'-ordered copy of {@code linspace(1, 12)}.
     */
    @Test
    public void testGemmWithOpsCommonsMath() {
        List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
        List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
        List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
        List<Pair<INDArray, String>> secondT = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, SEED, DataType.DOUBLE);
        double[] alpha = {1.0, -0.5, 2.5};
        double[] beta = {0.0, -0.25, 1.5};
        INDArray cOrig = Nd4j.linspace(1, 12, 12).reshape(3, 4);

        for (int i = 0; i < first.size(); i++) {
            Pair<INDArray, String> p1 = first.get(i);
            Pair<INDArray, String> p1T = firstT.get(i);
            for (int j = 0; j < second.size(); j++) {
                Pair<INDArray, String> p2 = second.get(j);
                Pair<INDArray, String> p2T = secondT.get(j);
                for (double a : alpha) {
                    // Iterate beta independently of alpha. The previous code indexed beta with the
                    // alpha index, so only the diagonal (alpha, beta) pairs were ever exercised.
                    for (double b : beta) {
                        INDArray cff = fOrderCopy(cOrig);
                        INDArray cft = fOrderCopy(cOrig);
                        INDArray ctf = fOrderCopy(cOrig);
                        INDArray ctt = fOrderCopy(cOrig);

                        assertGemm(i, j, false, false, a, b, p1, p2, cff);
                        assertGemm(i, j, false, true, a, b, p1, p2T, cft);
                        assertGemm(i, j, true, false, a, b, p1T, p2, ctf);
                        assertGemm(i, j, true, true, a, b, p1T, p2T, ctt);

                        // Also: confirm that if the C array is uninitialized and beta is 0.0,
                        // we don't have issues like 0*NaN = NaN (BLAS allows C to be unset when beta == 0).
                        if (b == 0.0) {
                            cff.assign(Double.NaN);
                            cft.assign(Double.NaN);
                            ctf.assign(Double.NaN);
                            ctt.assign(Double.NaN);

                            assertGemm(i, j, false, false, a, b, p1, p2, cff);
                            assertGemm(i, j, false, true, a, b, p1, p2T, cft);
                            assertGemm(i, j, true, false, a, b, p1T, p2, ctf);
                            assertGemm(i, j, true, true, a, b, p1T, p2T, ctt);
                        }
                    }
                }
            }
        }
    }

    /** Returns a new 'f'-ordered array with the same shape as {@code source}, filled from it via {@code assign}. */
    private static INDArray fOrderCopy(INDArray source) {
        INDArray copy = Nd4j.create(source.shape(), 'f');
        copy.assign(source);
        return copy;
    }

    /**
     * Asserts that {@link CheckUtil#checkGemm} accepts the given gemm configuration, failing
     * with a message that identifies the matrices, transposes and scalars involved.
     */
    private static void assertGemm(int i, int j, boolean transposeA, boolean transposeB, double alpha,
                                   double beta, Pair<INDArray, String> a, Pair<INDArray, String> b, INDArray c) {
        String msg = getGemmErrorMsg(i, j, transposeA, transposeB, alpha, beta, a, b);
        assertTrue(msg, CheckUtil.checkGemm(a.getFirst(), b.getFirst(), c, transposeA, transposeB, alpha, beta,
                GEMM_MAX_REL_DIFF, GEMM_MIN_ABS_DIFF));
    }

    /** Builds a failure message of the form {@code "i,j - A.op(B)"} for a binary op comparison. */
    private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
                                                 Pair<INDArray, String> second) {
        return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
    }

    /** Builds a failure message describing one gemm configuration and the matrices used. */
    private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
                                          double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
        return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta
                + "). A=" + first.getSecond() + ", B=" + second.getSecond();
    }
}