154 lines
6.9 KiB
Java
154 lines
6.9 KiB
Java
|
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
||
|
|
||
|
package org.nd4j.linalg;
|
||
|
|
||
|
import org.junit.After;
|
||
|
import org.junit.Before;
|
||
|
import org.junit.Test;
|
||
|
import org.junit.runner.RunWith;
|
||
|
import org.junit.runners.Parameterized;
|
||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||
|
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||
|
import org.nd4j.linalg.checkutil.CheckUtil;
|
||
|
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
||
|
import org.nd4j.linalg.factory.Nd4j;
|
||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||
|
import org.nd4j.linalg.primitives.Pair;
|
||
|
import org.slf4j.Logger;
|
||
|
import org.slf4j.LoggerFactory;
|
||
|
|
||
|
import java.util.List;
|
||
|
|
||
|
import static org.junit.Assert.assertTrue;
|
||
|
|
||
|
|
||
|
|
||
|
/**
|
||
|
* Tests comparing Nd4j ops to other libraries
|
||
|
*/
|
||
|
@RunWith(Parameterized.class)
|
||
|
public class Nd4jTestsComparisonC extends BaseNd4jTest {
|
||
|
private static Logger log = LoggerFactory.getLogger(Nd4jTestsComparisonC.class);
|
||
|
|
||
|
public static final int SEED = 123;
|
||
|
|
||
|
DataType initialType;
|
||
|
|
||
|
public Nd4jTestsComparisonC(Nd4jBackend backend) {
|
||
|
super(backend);
|
||
|
this.initialType = Nd4j.dataType();
|
||
|
}
|
||
|
|
||
|
|
||
|
@Before
|
||
|
public void before() throws Exception {
|
||
|
super.before();
|
||
|
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||
|
}
|
||
|
|
||
|
@After
|
||
|
public void after() throws Exception {
|
||
|
super.after();
|
||
|
DataTypeUtil.setDTypeForContext(initialType);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public char ordering() {
|
||
|
return 'c';
|
||
|
}
|
||
|
|
||
|
|
||
|
@Test
|
||
|
public void testGemmWithOpsCommonsMath() {
|
||
|
List<Pair<INDArray, String>> first = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 5, SEED, DataType.DOUBLE);
|
||
|
List<Pair<INDArray, String>> firstT = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 3, SEED, DataType.DOUBLE);
|
||
|
List<Pair<INDArray, String>> second = NDArrayCreationUtil.getAllTestMatricesWithShape(5, 4, SEED, DataType.DOUBLE);
|
||
|
List<Pair<INDArray, String>> secondT = NDArrayCreationUtil.getAllTestMatricesWithShape(4, 5, SEED, DataType.DOUBLE);
|
||
|
double[] alpha = {1.0, -0.5, 2.5};
|
||
|
double[] beta = {0.0, -0.25, 1.5};
|
||
|
INDArray cOrig = Nd4j.linspace(1, 12, 12 ).reshape(3, 4);
|
||
|
|
||
|
for (int i = 0; i < first.size(); i++) {
|
||
|
for (int j = 0; j < second.size(); j++) {
|
||
|
for (int k = 0; k < alpha.length; k++) {
|
||
|
for (int m = 0; m < beta.length; m++) {
|
||
|
INDArray cff = Nd4j.create(cOrig.shape(), 'f');
|
||
|
cff.assign(cOrig);
|
||
|
INDArray cft = Nd4j.create(cOrig.shape(), 'f');
|
||
|
cft.assign(cOrig);
|
||
|
INDArray ctf = Nd4j.create(cOrig.shape(), 'f');
|
||
|
ctf.assign(cOrig);
|
||
|
INDArray ctt = Nd4j.create(cOrig.shape(), 'f');
|
||
|
ctt.assign(cOrig);
|
||
|
|
||
|
double a = alpha[k];
|
||
|
double b = beta[k];
|
||
|
Pair<INDArray, String> p1 = first.get(i);
|
||
|
Pair<INDArray, String> p1T = firstT.get(i);
|
||
|
Pair<INDArray, String> p2 = second.get(j);
|
||
|
Pair<INDArray, String> p2T = secondT.get(j);
|
||
|
String errorMsgff = getGemmErrorMsg(i, j, false, false, a, b, p1, p2);
|
||
|
String errorMsgft = getGemmErrorMsg(i, j, false, true, a, b, p1, p2T);
|
||
|
String errorMsgtf = getGemmErrorMsg(i, j, true, false, a, b, p1T, p2);
|
||
|
String errorMsgtt = getGemmErrorMsg(i, j, true, true, a, b, p1T, p2T);
|
||
|
//System.out.println((String.format("Running iteration %d %d %d %d", i, j, k, m)));
|
||
|
assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false, a,
|
||
|
b, 1e-4, 1e-6));
|
||
|
assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true, a,
|
||
|
b, 1e-4, 1e-6));
|
||
|
assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false, a,
|
||
|
b, 1e-4, 1e-6));
|
||
|
assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true, a,
|
||
|
b, 1e-4, 1e-6));
|
||
|
|
||
|
//Also: Confirm that if the C array is uninitialized and beta is 0.0, we don't have issues like 0*NaN = NaN
|
||
|
if (b == 0.0) {
|
||
|
cff.assign(Double.NaN);
|
||
|
cft.assign(Double.NaN);
|
||
|
ctf.assign(Double.NaN);
|
||
|
ctt.assign(Double.NaN);
|
||
|
|
||
|
assertTrue(errorMsgff, CheckUtil.checkGemm(p1.getFirst(), p2.getFirst(), cff, false, false,
|
||
|
a, b, 1e-4, 1e-6));
|
||
|
assertTrue(errorMsgft, CheckUtil.checkGemm(p1.getFirst(), p2T.getFirst(), cft, false, true,
|
||
|
a, b, 1e-4, 1e-6));
|
||
|
assertTrue(errorMsgtf, CheckUtil.checkGemm(p1T.getFirst(), p2.getFirst(), ctf, true, false,
|
||
|
a, b, 1e-4, 1e-6));
|
||
|
assertTrue(errorMsgtt, CheckUtil.checkGemm(p1T.getFirst(), p2T.getFirst(), ctt, true, true,
|
||
|
a, b, 1e-4, 1e-6));
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
private static String getTestWithOpsErrorMsg(int i, int j, String op, Pair<INDArray, String> first,
|
||
|
Pair<INDArray, String> second) {
|
||
|
return i + "," + j + " - " + first.getSecond() + "." + op + "(" + second.getSecond() + ")";
|
||
|
}
|
||
|
|
||
|
private static String getGemmErrorMsg(int i, int j, boolean transposeA, boolean transposeB, double alpha,
|
||
|
double beta, Pair<INDArray, String> first, Pair<INDArray, String> second) {
|
||
|
return i + "," + j + " - gemm(tA=" + transposeA + ",tB=" + transposeB + ",alpha=" + alpha + ",beta=" + beta
|
||
|
+ "). A=" + first.getSecond() + ", B=" + second.getSecond();
|
||
|
}
|
||
|
}
|