Update Nd4jTestsC.java
parent
46ddd135a9
commit
459502f2bd
|
@ -134,12 +134,7 @@ import java.nio.ByteOrder;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
@ -258,7 +253,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray arr = Nd4j.rand(1, 20);
|
INDArray arr = Nd4j.rand(1, 20);
|
||||||
|
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.resolve("new-dir-" + UUID.randomUUID().toString()).toFile();
|
||||||
|
assertTrue(dir.mkdirs());
|
||||||
|
|
||||||
String outPath = FilenameUtils.concat(dir.getAbsolutePath(), "dl4jtestserialization.bin");
|
String outPath = FilenameUtils.concat(dir.getAbsolutePath(), "dl4jtestserialization.bin");
|
||||||
|
|
||||||
|
@ -1584,7 +1580,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testDivide(Nd4jBackend backend) {
|
public void testDivide(Nd4jBackend backend) {
|
||||||
INDArray two = Nd4j.create(new double[] {2, 2, 2, 2});
|
INDArray two = Nd4j.create(new double[] {2, 2, 2, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray div = two.div(two);
|
INDArray div = two.div(two);
|
||||||
assertEquals(Nd4j.ones(4), div);
|
assertEquals(Nd4j.ones(4), div);
|
||||||
|
|
||||||
|
@ -1600,7 +1596,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSigmoid(Nd4jBackend backend) {
|
public void testSigmoid(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f});
|
INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}).castTo(DataType.DOUBLE);
|
||||||
INDArray sigmoid = Transforms.sigmoid(n, false);
|
INDArray sigmoid = Transforms.sigmoid(n, false);
|
||||||
assertEquals(assertion, sigmoid);
|
assertEquals(assertion, sigmoid);
|
||||||
}
|
}
|
||||||
|
@ -1608,8 +1604,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNeg(Nd4jBackend backend) {
|
public void testNeg(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}).castTo(DataType.DOUBLE);
|
||||||
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4});
|
INDArray assertion = Nd4j.create(new float[] {-1, -2, -3, -4}).castTo(DataType.DOUBLE);
|
||||||
INDArray neg = Transforms.neg(n);
|
INDArray neg = Transforms.neg(n);
|
||||||
assertEquals(assertion, neg,getFailureMessage(backend));
|
assertEquals(assertion, neg,getFailureMessage(backend));
|
||||||
|
|
||||||
|
@ -1619,14 +1615,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm2Double(Nd4jBackend backend) {
|
public void testNorm2Double(Nd4jBackend backend) {
|
||||||
DataType initialType = Nd4j.dataType();
|
DataType initialType = Nd4j.dataType();
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
|
||||||
|
|
||||||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}).castTo(DataType.DOUBLE);
|
||||||
double assertion = 5.47722557505;
|
double assertion = 5.47722557505;
|
||||||
double norm3 = n.norm2Number().doubleValue();
|
double norm3 = n.norm2Number().doubleValue();
|
||||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new double[] {1, 2, 3, 4}, new long[] {2, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
double norm2 = row1.norm2Number().doubleValue();
|
double norm2 = row1.norm2Number().doubleValue();
|
||||||
double assertion2 = 5.0f;
|
double assertion2 = 5.0f;
|
||||||
|
@ -1639,13 +1634,13 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNorm2(Nd4jBackend backend) {
|
public void testNorm2(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}).castTo(DataType.DOUBLE);
|
||||||
float assertion = 5.47722557505f;
|
float assertion = 5.47722557505f;
|
||||||
float norm3 = n.norm2Number().floatValue();
|
float norm3 = n.norm2Number().floatValue();
|
||||||
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
assertEquals(assertion, norm3, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
|
|
||||||
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray row = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray row1 = row.getRow(1);
|
INDArray row1 = row.getRow(1);
|
||||||
float norm2 = row1.norm2Number().floatValue();
|
float norm2 = row1.norm2Number().floatValue();
|
||||||
float assertion2 = 5.0f;
|
float assertion2 = 5.0f;
|
||||||
|
@ -1658,8 +1653,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testCosineSim(Nd4jBackend backend) {
|
public void testCosineSim(Nd4jBackend backend) {
|
||||||
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new double[] {1, 2, 3, 4}).castTo(DataType.DOUBLE);
|
||||||
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new double[] {1, 2, 3, 4}).castTo(DataType.DOUBLE);
|
||||||
double sim = Transforms.cosineSim(vec1, vec2);
|
double sim = Transforms.cosineSim(vec1, vec2);
|
||||||
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
assertEquals(1, sim, 1e-1,getFailureMessage(backend));
|
||||||
|
|
||||||
|
@ -1675,7 +1670,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScal(Nd4jBackend backend) {
|
public void testScal(Nd4jBackend backend) {
|
||||||
double assertion = 2;
|
double assertion = 2;
|
||||||
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8});
|
INDArray answer = Nd4j.create(new double[] {2, 4, 6, 8}).castTo(DataType.DOUBLE);
|
||||||
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
INDArray scal = Nd4j.getBlasWrapper().scal(assertion, answer);
|
||||||
assertEquals(answer, scal,getFailureMessage(backend));
|
assertEquals(answer, scal,getFailureMessage(backend));
|
||||||
|
|
||||||
|
@ -1691,8 +1686,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testExp(Nd4jBackend backend) {
|
public void testExp(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4});
|
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4}).castTo(DataType.DOUBLE);
|
||||||
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f});
|
INDArray assertion = Nd4j.create(new double[] {2.71828183f, 7.3890561f, 20.08553692f, 54.59815003f}).castTo(DataType.DOUBLE);
|
||||||
INDArray exped = Transforms.exp(n);
|
INDArray exped = Transforms.exp(n);
|
||||||
assertEquals(assertion, exped);
|
assertEquals(assertion, exped);
|
||||||
|
|
||||||
|
@ -1715,10 +1710,10 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testScalar(Nd4jBackend backend) {
|
public void testScalar(Nd4jBackend backend) {
|
||||||
INDArray a = Nd4j.scalar(1.0f);
|
INDArray a = Nd4j.scalar(1.0f).castTo(DataType.DOUBLE);
|
||||||
assertEquals(true, a.isScalar());
|
assertEquals(true, a.isScalar());
|
||||||
|
|
||||||
INDArray n = Nd4j.create(new float[] {1.0f}, new long[0]);
|
INDArray n = Nd4j.create(new float[] {1.0f}, new long[0]).castTo(DataType.DOUBLE);
|
||||||
assertEquals(n, a);
|
assertEquals(n, a);
|
||||||
assertTrue(n.isScalar());
|
assertTrue(n.isScalar());
|
||||||
}
|
}
|
||||||
|
@ -1768,7 +1763,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testColumns(Nd4jBackend backend) {
|
public void testColumns(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.create(new long[] {3, 2});
|
INDArray arr = Nd4j.create(new long[] {3, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray column2 = arr.getColumn(0);
|
INDArray column2 = arr.getColumn(0);
|
||||||
//assertEquals(true, Shape.shapeEquals(new long[]{3, 1}, column2.shape()));
|
//assertEquals(true, Shape.shapeEquals(new long[]{3, 1}, column2.shape()));
|
||||||
INDArray column = Nd4j.create(new double[] {1, 2, 3}, new long[] {3});
|
INDArray column = Nd4j.create(new double[] {1, 2, 3}, new long[] {3});
|
||||||
|
@ -1902,7 +1897,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInplaceTranspose(Nd4jBackend backend) {
|
public void testInplaceTranspose(Nd4jBackend backend) {
|
||||||
INDArray test = Nd4j.rand(3, 4);
|
INDArray test = Nd4j.rand(3, 4).castTo(DataType.DOUBLE);
|
||||||
INDArray orig = test.dup();
|
INDArray orig = test.dup();
|
||||||
INDArray transposei = test.transposei();
|
INDArray transposei = test.transposei();
|
||||||
|
|
||||||
|
@ -1918,7 +1913,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testTADMMul(Nd4jBackend backend) {
|
public void testTADMMul(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
val shape = new long[] {4, 5, 7};
|
val shape = new long[] {4, 5, 7};
|
||||||
INDArray arr = Nd4j.rand(shape);
|
INDArray arr = Nd4j.rand(shape).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray tad = arr.tensorAlongDimension(0, 1, 2);
|
INDArray tad = arr.tensorAlongDimension(0, 1, 2);
|
||||||
assertArrayEquals(tad.shape(), new long[] {5, 7});
|
assertArrayEquals(tad.shape(), new long[] {5, 7});
|
||||||
|
@ -1935,7 +1930,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertTrue(tad.equals(copy));
|
assertTrue(tad.equals(copy));
|
||||||
tad = tad.reshape(7, 5);
|
tad = tad.reshape(7, 5);
|
||||||
copy = copy.reshape(7, 5);
|
copy = copy.reshape(7, 5);
|
||||||
INDArray first = Nd4j.rand(new long[] {2, 7});
|
INDArray first = Nd4j.rand(new long[] {2, 7}).castTo(DataType.DOUBLE);
|
||||||
INDArray mmul = first.mmul(tad);
|
INDArray mmul = first.mmul(tad);
|
||||||
INDArray mmulCopy = first.mmul(copy);
|
INDArray mmulCopy = first.mmul(copy);
|
||||||
|
|
||||||
|
@ -1947,14 +1942,14 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testTADMMulLeadingOne(Nd4jBackend backend) {
|
public void testTADMMulLeadingOne(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
val shape = new long[] {1, 5, 7};
|
val shape = new long[] {1, 5, 7};
|
||||||
INDArray arr = Nd4j.rand(shape);
|
INDArray arr = Nd4j.rand(shape).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray tad = arr.tensorAlongDimension(0, 1, 2);
|
INDArray tad = arr.tensorAlongDimension(0, 1, 2);
|
||||||
boolean order = Shape.cOrFortranOrder(tad.shape(), tad.stride(), 1);
|
boolean order = Shape.cOrFortranOrder(tad.shape(), tad.stride(), 1);
|
||||||
assertArrayEquals(tad.shape(), new long[] {5, 7});
|
assertArrayEquals(tad.shape(), new long[] {5, 7});
|
||||||
|
|
||||||
|
|
||||||
INDArray copy = Nd4j.zeros(5, 7);
|
INDArray copy = Nd4j.zeros(5, 7).castTo(DataType.DOUBLE);
|
||||||
for (int i = 0; i < 5; i++) {
|
for (int i = 0; i < 5; i++) {
|
||||||
for (int j = 0; j < 7; j++) {
|
for (int j = 0; j < 7; j++) {
|
||||||
copy.putScalar(new long[] {i, j}, tad.getDouble(i, j));
|
copy.putScalar(new long[] {i, j}, tad.getDouble(i, j));
|
||||||
|
@ -1965,7 +1960,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
tad = tad.reshape(7, 5);
|
tad = tad.reshape(7, 5);
|
||||||
copy = copy.reshape(7, 5);
|
copy = copy.reshape(7, 5);
|
||||||
INDArray first = Nd4j.rand(new long[] {2, 7});
|
INDArray first = Nd4j.rand(new long[] {2, 7}).castTo(DataType.DOUBLE);
|
||||||
INDArray mmul = first.mmul(tad);
|
INDArray mmul = first.mmul(tad);
|
||||||
INDArray mmulCopy = first.mmul(copy);
|
INDArray mmulCopy = first.mmul(copy);
|
||||||
|
|
||||||
|
@ -1976,11 +1971,11 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSum2(Nd4jBackend backend) {
|
public void testSum2(Nd4jBackend backend) {
|
||||||
INDArray test = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2});
|
INDArray test = Nd4j.create(new float[] {1, 2, 3, 4}, new long[] {2, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray sum = test.sum(1);
|
INDArray sum = test.sum(1);
|
||||||
INDArray assertion = Nd4j.create(new float[] {3, 7});
|
INDArray assertion = Nd4j.create(new float[] {3, 7}).castTo(DataType.DOUBLE);
|
||||||
assertEquals(assertion, sum);
|
assertEquals(assertion, sum);
|
||||||
INDArray sum0 = Nd4j.create(new float[] {4, 6});
|
INDArray sum0 = Nd4j.create(new float[] {4, 6}).castTo(DataType.DOUBLE);
|
||||||
assertEquals(sum0, test.sum(0));
|
assertEquals(sum0, test.sum(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1991,7 +1986,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
int[] shape = {3, 2, 4};
|
int[] shape = {3, 2, 4};
|
||||||
INDArray arr3d = Nd4j.rand(shape);
|
INDArray arr3d = Nd4j.rand(shape).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray get0 = arr3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1));
|
INDArray get0 = arr3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 1));
|
||||||
INDArray getPoint0 = arr3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0));
|
INDArray getPoint0 = arr3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||||
|
@ -2036,7 +2031,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
int[] shape = {3, 2, 4};
|
int[] shape = {3, 2, 4};
|
||||||
INDArray arr3d = Nd4j.rand(shape);
|
INDArray arr3d = Nd4j.rand(shape).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
for (int x = 0; x < 4; x++) {
|
for (int x = 0; x < 4; x++) {
|
||||||
INDArray getInterval = arr3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(x, x + 1)); //3d
|
INDArray getInterval = arr3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(x, x + 1)); //3d
|
||||||
|
@ -2059,19 +2054,19 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmul(Nd4jBackend backend) {
|
public void testMmul(Nd4jBackend backend) {
|
||||||
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
|
DataBuffer data = Nd4j.linspace(1, 10, 10, DataType.DOUBLE).data();
|
||||||
INDArray n = Nd4j.create(data, new long[] {1, 10});
|
INDArray n = Nd4j.create(data, new long[] {1, 10}).castTo(DataType.DOUBLE);
|
||||||
INDArray transposed = n.transpose();
|
INDArray transposed = n.transpose();
|
||||||
assertEquals(true, n.isRowVector());
|
assertEquals(true, n.isRowVector());
|
||||||
assertEquals(true, transposed.isColumnVector());
|
assertEquals(true, transposed.isColumnVector());
|
||||||
|
|
||||||
INDArray d = Nd4j.create(n.rows(), n.columns());
|
INDArray d = Nd4j.create(n.rows(), n.columns()).castTo(DataType.DOUBLE);
|
||||||
d.setData(n.data());
|
d.setData(n.data());
|
||||||
|
|
||||||
|
|
||||||
INDArray d3 = Nd4j.create(new double[] {1, 2}).reshape(2, 1);
|
INDArray d3 = Nd4j.create(new double[] {1, 2}).reshape(2, 1);
|
||||||
INDArray d4 = Nd4j.create(new double[] {3, 4}).reshape(1, 2);
|
INDArray d4 = Nd4j.create(new double[] {3, 4}).reshape(1, 2);
|
||||||
INDArray resultNDArray = d3.mmul(d4);
|
INDArray resultNDArray = d3.mmul(d4);
|
||||||
INDArray result = Nd4j.create(new double[][] {{3, 4}, {6, 8}});
|
INDArray result = Nd4j.create(new double[][] {{3, 4}, {6, 8}}).castTo(DataType.DOUBLE);
|
||||||
assertEquals(result, resultNDArray);
|
assertEquals(result, resultNDArray);
|
||||||
|
|
||||||
|
|
||||||
|
@ -2085,12 +2080,12 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
INDArray three = Nd4j.create(new double[] {3, 4});
|
INDArray three = Nd4j.create(new double[] {3, 4}).castTo(DataType.DOUBLE);
|
||||||
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
INDArray test = Nd4j.create(Nd4j.linspace(1, 30, 30, DataType.DOUBLE).data(), new long[] {3, 5, 2});
|
||||||
INDArray sliceRow = test.slice(0).getRow(1);
|
INDArray sliceRow = test.slice(0).getRow(1);
|
||||||
assertEquals(three, sliceRow,getFailureMessage(backend));
|
assertEquals(three, sliceRow,getFailureMessage(backend));
|
||||||
|
|
||||||
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1});
|
INDArray twoSix = Nd4j.create(new double[] {2, 6}, new long[] {2, 1}).castTo(DataType.DOUBLE);
|
||||||
INDArray threeTwoSix = three.mmul(twoSix);
|
INDArray threeTwoSix = three.mmul(twoSix);
|
||||||
|
|
||||||
INDArray sliceRowTwoSix = sliceRow.mmul(twoSix);
|
INDArray sliceRowTwoSix = sliceRow.mmul(twoSix);
|
||||||
|
@ -2109,7 +2104,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
154, 165, 0, 12, 24, 36, 48, 60, 72, 84, 96, 108, 120, 132, 144, 156, 168, 180, 0, 13, 26, 39,
|
154, 165, 0, 12, 24, 36, 48, 60, 72, 84, 96, 108, 120, 132, 144, 156, 168, 180, 0, 13, 26, 39,
|
||||||
52, 65, 78, 91, 104, 117, 130, 143, 156, 169, 182, 195, 0, 14, 28, 42, 56, 70, 84, 98, 112, 126,
|
52, 65, 78, 91, 104, 117, 130, 143, 156, 169, 182, 195, 0, 14, 28, 42, 56, 70, 84, 98, 112, 126,
|
||||||
140, 154, 168, 182, 196, 210, 0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180, 195, 210,
|
140, 154, 168, 182, 196, 210, 0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180, 195, 210,
|
||||||
225}, new long[] {16, 16});
|
225}, new long[] {16, 16}).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
|
|
||||||
INDArray n1 = Nd4j.create(Nd4j.linspace(0, 15, 16, DataType.DOUBLE).data(), new long[] {1, 16});
|
INDArray n1 = Nd4j.create(Nd4j.linspace(0, 15, 16, DataType.DOUBLE).data(), new long[] {1, 16});
|
||||||
|
@ -2167,7 +2162,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testLogX1(Nd4jBackend backend) {
|
public void testLogX1(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.create(10).assign(7);
|
INDArray x = Nd4j.create(10).assign(7).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray logX5 = Transforms.log(x, 5, true);
|
INDArray logX5 = Transforms.log(x, 5, true);
|
||||||
|
|
||||||
|
@ -2179,7 +2174,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testAddMatrix(Nd4jBackend backend) {
|
public void testAddMatrix(Nd4jBackend backend) {
|
||||||
INDArray five = Nd4j.ones(5);
|
INDArray five = Nd4j.ones(5).castTo(DataType.DOUBLE);
|
||||||
five.addi(five);
|
five.addi(five);
|
||||||
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
INDArray twos = Nd4j.valueArrayOf(5, 2);
|
||||||
assertEquals(twos, five);
|
assertEquals(twos, five);
|
||||||
|
@ -2201,7 +2196,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
public void testRowVectorMultipleIndices(Nd4jBackend backend) {
|
||||||
INDArray linear = Nd4j.create(1, 4);
|
INDArray linear = Nd4j.create(1, 4).castTo(DataType.DOUBLE);
|
||||||
linear.putScalar(new long[] {0, 1}, 1);
|
linear.putScalar(new long[] {0, 1}, 1);
|
||||||
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
assertEquals(linear.getDouble(0, 1), 1, 1e-1);
|
||||||
}
|
}
|
||||||
|
@ -2210,7 +2205,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testSize(Nd4jBackend backend) {
|
public void testSize(Nd4jBackend backend) {
|
||||||
assertThrows(IllegalArgumentException.class,() -> {
|
assertThrows(IllegalArgumentException.class,() -> {
|
||||||
INDArray arr = Nd4j.create(4, 5);
|
INDArray arr = Nd4j.create(4, 5).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
for (int i = 0; i < 6; i++) {
|
for (int i = 0; i < 6; i++) {
|
||||||
//This should fail for i >= 2, but doesn't
|
//This should fail for i >= 2, but doesn't
|
||||||
|
@ -2224,10 +2219,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNullPointerDataBuffer(Nd4jBackend backend) {
|
public void testNullPointerDataBuffer(Nd4jBackend backend) {
|
||||||
DataType initialType = Nd4j.dataType();
|
|
||||||
|
|
||||||
Nd4j.setDataType(DataType.FLOAT);
|
|
||||||
|
|
||||||
ByteBuffer allocate = ByteBuffer.allocateDirect(10 * 4).order(ByteOrder.nativeOrder());
|
ByteBuffer allocate = ByteBuffer.allocateDirect(10 * 4).order(ByteOrder.nativeOrder());
|
||||||
allocate.asFloatBuffer().put(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
|
allocate.asFloatBuffer().put(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
|
||||||
DataBuffer buff = Nd4j.createBuffer(allocate, DataType.FLOAT, 10);
|
DataBuffer buff = Nd4j.createBuffer(allocate, DataType.FLOAT, 10);
|
||||||
|
@ -2235,7 +2226,6 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
// System.out.println(sum);
|
// System.out.println(sum);
|
||||||
assertEquals(55f, sum, 0.001f);
|
assertEquals(55f, sum, 0.001f);
|
||||||
|
|
||||||
Nd4j.setDataType(initialType);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
|
@ -2253,8 +2243,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEps2(Nd4jBackend backend) {
|
public void testEps2(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray first = Nd4j.valueArrayOf(10, 1e-2); //0.01
|
INDArray first = Nd4j.valueArrayOf(10, 1e-2).castTo(DataType.DOUBLE); //0.01
|
||||||
INDArray second = Nd4j.zeros(10); //0.0
|
INDArray second = Nd4j.zeros(10).castTo(DataType.DOUBLE); //0.0
|
||||||
|
|
||||||
INDArray expAllZeros1 = Nd4j.getExecutioner().exec(new Eps(first, second, Nd4j.create(DataType.BOOL, new long[] {1, 10}, 'f')));
|
INDArray expAllZeros1 = Nd4j.getExecutioner().exec(new Eps(first, second, Nd4j.create(DataType.BOOL, new long[] {1, 10}, 'f')));
|
||||||
INDArray expAllZeros2 = Nd4j.getExecutioner().exec(new Eps(second, first, Nd4j.create(DataType.BOOL, new long[] {1, 10}, 'f')));
|
INDArray expAllZeros2 = Nd4j.getExecutioner().exec(new Eps(second, first, Nd4j.create(DataType.BOOL, new long[] {1, 10}, 'f')));
|
||||||
|
@ -2336,7 +2326,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void test2DArraySlice(Nd4jBackend backend) {
|
public void test2DArraySlice(Nd4jBackend backend) {
|
||||||
INDArray array2D = Nd4j.ones(5, 7);
|
INDArray array2D = Nd4j.ones(5, 7).castTo(DataType.DOUBLE);
|
||||||
/**
|
/**
|
||||||
* This should be reverse.
|
* This should be reverse.
|
||||||
* This is compatibility with numpy.
|
* This is compatibility with numpy.
|
||||||
|
@ -2390,7 +2380,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testGetRow(Nd4jBackend backend) {
|
public void testGetRow(Nd4jBackend backend) {
|
||||||
INDArray arr = Nd4j.ones(10, 4);
|
INDArray arr = Nd4j.ones(10, 4).castTo(DataType.DOUBLE);
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
INDArray row = arr.getRow(i);
|
INDArray row = arr.getRow(i);
|
||||||
assertArrayEquals(row.shape(), new long[] {4});
|
assertArrayEquals(row.shape(), new long[] {4});
|
||||||
|
@ -2403,7 +2393,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testGetPermuteReshapeSub(Nd4jBackend backend) {
|
public void testGetPermuteReshapeSub(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
INDArray first = Nd4j.rand(new long[] {10, 4});
|
INDArray first = Nd4j.rand(new long[] {10, 4}).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
//Reshape, as per RnnOutputLayer etc on labels
|
//Reshape, as per RnnOutputLayer etc on labels
|
||||||
INDArray orig3d = Nd4j.rand(new long[] {2, 4, 15});
|
INDArray orig3d = Nd4j.rand(new long[] {2, 4, 15});
|
||||||
|
@ -2423,7 +2413,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPutAtIntervalIndexWithStride(Nd4jBackend backend) {
|
public void testPutAtIntervalIndexWithStride(Nd4jBackend backend) {
|
||||||
INDArray n1 = Nd4j.create(3, 3).assign(0.0);
|
INDArray n1 = Nd4j.create(3, 3).assign(0.0.castTo(DataType.DOUBLE));
|
||||||
INDArrayIndex[] indices = {NDArrayIndex.interval(0, 2, 3), NDArrayIndex.all()};
|
INDArrayIndex[] indices = {NDArrayIndex.interval(0, 2, 3), NDArrayIndex.all()};
|
||||||
n1.put(indices, 1);
|
n1.put(indices, 1);
|
||||||
INDArray expected = Nd4j.create(new double[][] {{1d, 1d, 1d}, {0d, 0d, 0d}, {1d, 1d, 1d}});
|
INDArray expected = Nd4j.create(new double[][] {{1d, 1d, 1d}, {0d, 0d, 0d}, {1d, 1d, 1d}});
|
||||||
|
@ -2434,11 +2424,11 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMMulMatrixTimesColVector(Nd4jBackend backend) {
|
public void testMMulMatrixTimesColVector(Nd4jBackend backend) {
|
||||||
//[1 1 1 1 1; 10 10 10 10 10; 100 100 100 100 100] x [1; 1; 1; 1; 1] = [5; 50; 500]
|
//[1 1 1 1 1; 10 10 10 10 10; 100 100 100 100 100] x [1; 1; 1; 1; 1] = [5; 50; 500]
|
||||||
INDArray matrix = Nd4j.ones(3, 5);
|
INDArray matrix = Nd4j.ones(3, 5).castTo(DataType.DOUBLE);
|
||||||
matrix.getRow(1).muli(10);
|
matrix.getRow(1).muli(10);
|
||||||
matrix.getRow(2).muli(100);
|
matrix.getRow(2).muli(100);
|
||||||
|
|
||||||
INDArray colVector = Nd4j.ones(5, 1);
|
INDArray colVector = Nd4j.ones(5, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray out = matrix.mmul(colVector);
|
INDArray out = matrix.mmul(colVector);
|
||||||
|
|
||||||
INDArray expected = Nd4j.create(new double[] {5, 50, 500}, new long[] {3, 1});
|
INDArray expected = Nd4j.create(new double[] {5, 50, 500}, new long[] {3, 1});
|
||||||
|
@ -2449,8 +2439,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMMulMixedOrder(Nd4jBackend backend) {
|
public void testMMulMixedOrder(Nd4jBackend backend) {
|
||||||
INDArray first = Nd4j.ones(5, 2);
|
INDArray first = Nd4j.ones(5, 2).castTo(DataType.DOUBLE);
|
||||||
INDArray second = Nd4j.ones(2, 3);
|
INDArray second = Nd4j.ones(2, 3).castTo(DataType.DOUBLE);
|
||||||
INDArray out = first.mmul(second);
|
INDArray out = first.mmul(second);
|
||||||
assertArrayEquals(out.shape(), new long[] {5, 3});
|
assertArrayEquals(out.shape(), new long[] {5, 3});
|
||||||
assertTrue(out.equals(Nd4j.ones(5, 3).muli(2)));
|
assertTrue(out.equals(Nd4j.ones(5, 3).muli(2)));
|
||||||
|
@ -2475,9 +2465,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testFTimesCAddiRow(Nd4jBackend backend) {
|
public void testFTimesCAddiRow(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray arrF = Nd4j.create(2, 3, 'f').assign(1.0);
|
INDArray arrF = Nd4j.create(2, 3, 'f').assign(1.0).castTo(DataType.DOUBLE);
|
||||||
INDArray arrC = Nd4j.create(2, 3, 'c').assign(1.0);
|
INDArray arrC = Nd4j.create(2, 3, 'c').assign(1.0).castTo(DataType.DOUBLE);
|
||||||
INDArray arr2 = Nd4j.create(new long[] {3, 4}, 'c').assign(1.0);
|
INDArray arr2 = Nd4j.create(new long[] {3, 4}, 'c').assign(1.0).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray mmulC = arrC.mmul(arr2); //[2,4] with elements 3.0
|
INDArray mmulC = arrC.mmul(arr2); //[2,4] with elements 3.0
|
||||||
INDArray mmulF = arrF.mmul(arr2); //[2,4] with elements 3.0
|
INDArray mmulF = arrF.mmul(arr2); //[2,4] with elements 3.0
|
||||||
|
@ -2485,7 +2475,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
assertArrayEquals(mmulF.shape(), new long[] {2, 4});
|
assertArrayEquals(mmulF.shape(), new long[] {2, 4});
|
||||||
assertTrue(arrC.equals(arrF));
|
assertTrue(arrC.equals(arrF));
|
||||||
|
|
||||||
INDArray row = Nd4j.zeros(1, 4).assign(0.0).addi(0.5);
|
INDArray row = Nd4j.zeros(1, 4).assign(0.0).addi(0.5).castTo(DataType.DOUBLE);
|
||||||
mmulC.addiRowVector(row); //OK
|
mmulC.addiRowVector(row); //OK
|
||||||
mmulF.addiRowVector(row); //Exception
|
mmulF.addiRowVector(row); //Exception
|
||||||
|
|
||||||
|
@ -2503,8 +2493,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMmulGet(Nd4jBackend backend) {
|
public void testMmulGet(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345L);
|
Nd4j.getRandom().setSeed(12345L);
|
||||||
INDArray elevenByTwo = Nd4j.rand(new long[] {11, 2});
|
INDArray elevenByTwo = Nd4j.rand(new long[] {11, 2}).castTo(DataType.DOUBLE);
|
||||||
INDArray twoByEight = Nd4j.rand(new long[] {2, 8});
|
INDArray twoByEight = Nd4j.rand(new long[] {2, 8}).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray view = twoByEight.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2));
|
INDArray view = twoByEight.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2));
|
||||||
INDArray viewCopy = view.dup();
|
INDArray viewCopy = view.dup();
|
||||||
|
@ -2520,15 +2510,15 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMMulRowColVectorMixedOrder(Nd4jBackend backend) {
|
public void testMMulRowColVectorMixedOrder(Nd4jBackend backend) {
|
||||||
INDArray colVec = Nd4j.ones(5, 1);
|
INDArray colVec = Nd4j.ones(5, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray rowVec = Nd4j.ones(1, 3);
|
INDArray rowVec = Nd4j.ones(1, 3).castTo(DataType.DOUBLE);
|
||||||
INDArray out = colVec.mmul(rowVec);
|
INDArray out = colVec.mmul(rowVec);
|
||||||
assertArrayEquals(out.shape(), new long[] {5, 3});
|
assertArrayEquals(out.shape(), new long[] {5, 3});
|
||||||
assertTrue(out.equals(Nd4j.ones(5, 3)));
|
assertTrue(out.equals(Nd4j.ones(5, 3)));
|
||||||
//Above: OK
|
//Above: OK
|
||||||
|
|
||||||
INDArray colVectorC = Nd4j.create(new long[] {5, 1}, 'c');
|
INDArray colVectorC = Nd4j.create(new long[] {5, 1}, 'c').castTo(DataType.DOUBLE);
|
||||||
INDArray rowVectorF = Nd4j.create(new long[] {1, 3}, 'f');
|
INDArray rowVectorF = Nd4j.create(new long[] {1, 3}, 'f').castTo(DataType.DOUBLE);
|
||||||
for (int i = 0; i < colVectorC.length(); i++)
|
for (int i = 0; i < colVectorC.length(); i++)
|
||||||
colVectorC.putScalar(i, 1.0);
|
colVectorC.putScalar(i, 1.0);
|
||||||
for (int i = 0; i < rowVectorF.length(); i++)
|
for (int i = 0; i < rowVectorF.length(); i++)
|
||||||
|
@ -2548,9 +2538,9 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
int nCols = 3;
|
int nCols = 3;
|
||||||
java.util.Random r = new java.util.Random(12345);
|
java.util.Random r = new java.util.Random(12345);
|
||||||
|
|
||||||
INDArray arrC = Nd4j.create(new long[] {nRows, nCols}, 'c');
|
INDArray arrC = Nd4j.create(new long[] {nRows, nCols}, 'c').castTo(DataType.DOUBLE);
|
||||||
INDArray arrF = Nd4j.create(new long[] {nRows, nCols}, 'f');
|
INDArray arrF = Nd4j.create(new long[] {nRows, nCols}, 'f').castTo(DataType.DOUBLE);
|
||||||
INDArray arrC2 = Nd4j.create(new long[] {nRows, nCols}, 'c');
|
INDArray arrC2 = Nd4j.create(new long[] {nRows, nCols}, 'c').castTo(DataType.DOUBLE);
|
||||||
for (int i = 0; i < nRows; i++) {
|
for (int i = 0; i < nRows; i++) {
|
||||||
for (int j = 0; j < nCols; j++) {
|
for (int j = 0; j < nCols; j++) {
|
||||||
double rv = r.nextDouble();
|
double rv = r.nextDouble();
|
||||||
|
@ -2570,8 +2560,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testMMulColVectorRowVectorMixedOrder(Nd4jBackend backend) {
|
public void testMMulColVectorRowVectorMixedOrder(Nd4jBackend backend) {
|
||||||
INDArray colVec = Nd4j.ones(5, 1);
|
INDArray colVec = Nd4j.ones(5, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray rowVec = Nd4j.ones(1, 5);
|
INDArray rowVec = Nd4j.ones(1, 5).castTo(DataType.DOUBLE);
|
||||||
INDArray out = rowVec.mmul(colVec);
|
INDArray out = rowVec.mmul(colVec);
|
||||||
assertArrayEquals(new long[] {1, 1}, out.shape());
|
assertArrayEquals(new long[] {1, 1}, out.shape());
|
||||||
assertTrue(out.equals(Nd4j.ones(1, 1).muli(5)));
|
assertTrue(out.equals(Nd4j.ones(1, 1).muli(5)));
|
||||||
|
@ -2593,7 +2583,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermute(Nd4jBackend backend) {
|
public void testPermute(Nd4jBackend backend) {
|
||||||
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4});
|
INDArray n = Nd4j.create(Nd4j.linspace(1, 20, 20, DataType.DOUBLE).data(), new long[] {5, 4}).castTo(DataType.DOUBLE);
|
||||||
INDArray transpose = n.transpose();
|
INDArray transpose = n.transpose();
|
||||||
INDArray permute = n.permute(1, 0);
|
INDArray permute = n.permute(1, 0);
|
||||||
assertEquals(permute, transpose);
|
assertEquals(permute, transpose);
|
||||||
|
@ -2612,7 +2602,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
//Check in-place permute vs. copy array permute
|
//Check in-place permute vs. copy array permute
|
||||||
|
|
||||||
//2d:
|
//2d:
|
||||||
INDArray orig = Nd4j.linspace(1, 3 * 4, 3 * 4, DataType.DOUBLE).reshape('c', 3, 4);
|
INDArray orig = Nd4j.linspace(1, 3 * 4, 3 * 4, DataType.DOUBLE).reshape('c', 3, 4).castTo(DataType.DOUBLE);
|
||||||
INDArray exp01 = orig.permute(0, 1);
|
INDArray exp01 = orig.permute(0, 1);
|
||||||
INDArray exp10 = orig.permute(1, 0);
|
INDArray exp10 = orig.permute(1, 0);
|
||||||
List<Pair<INDArray, String>> list1 = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE);
|
List<Pair<INDArray, String>> list1 = NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345, DataType.DOUBLE);
|
||||||
|
@ -2692,7 +2682,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testPermuteiShape(Nd4jBackend backend) {
|
public void testPermuteiShape(Nd4jBackend backend) {
|
||||||
|
|
||||||
INDArray row = Nd4j.create(1, 10);
|
INDArray row = Nd4j.create(1, 10).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray permutedCopy = row.permute(1, 0);
|
INDArray permutedCopy = row.permute(1, 0);
|
||||||
INDArray permutedInplace = row.permutei(1, 0);
|
INDArray permutedInplace = row.permutei(1, 0);
|
||||||
|
@ -2797,7 +2787,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testStdev1(Nd4jBackend backend) {
|
public void testStdev1(Nd4jBackend backend) {
|
||||||
double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}};
|
double[][] ind = {{5.1, 3.5, 1.4}, {4.9, 3.0, 1.4}, {4.7, 3.2, 1.3}};
|
||||||
INDArray in = Nd4j.create(ind);
|
INDArray in = Nd4j.create(ind).castTo(DataType.DOUBLE);
|
||||||
INDArray stdev = in.std(1);
|
INDArray stdev = in.std(1);
|
||||||
// log.info("StdDev: {}", stdev.toDoubleVector());
|
// log.info("StdDev: {}", stdev.toDoubleVector());
|
||||||
INDArray exp = Nd4j.create(new double[] {1.8556220879622372, 1.7521415467935233, 1.7039170558842744});
|
INDArray exp = Nd4j.create(new double[] {1.8556220879622372, 1.7521415467935233, 1.7039170558842744});
|
||||||
|
@ -2811,8 +2801,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
double[] d = {1.0, -1.1, 1.2, 1.3, -1.4, -1.5, 1.6, -1.7, -1.8, -1.9, -1.01, -1.011};
|
double[] d = {1.0, -1.1, 1.2, 1.3, -1.4, -1.5, 1.6, -1.7, -1.8, -1.9, -1.01, -1.011};
|
||||||
double[] e = {1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0};
|
double[] e = {1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0};
|
||||||
|
|
||||||
INDArray arrF = Nd4j.create(d, new long[] {4, 3}, 'f');
|
INDArray arrF = Nd4j.create(d, new long[] {4, 3}, 'f').castTo(DataType.DOUBLE);
|
||||||
INDArray arrC = Nd4j.create(new long[] {4, 3}, 'c').assign(arrF);
|
INDArray arrC = Nd4j.create(new long[] {4, 3}, 'c').assign(arrF).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray exp = Nd4j.create(e, new long[] {4, 3}, 'f');
|
INDArray exp = Nd4j.create(e, new long[] {4, 3}, 'f');
|
||||||
|
|
||||||
|
@ -3036,7 +3026,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTemp(Nd4jBackend backend) {
|
public void testTemp(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray in = Nd4j.rand(new long[] {2, 2, 2});
|
INDArray in = Nd4j.rand(new long[] {2, 2, 2}).castTo(DataType.DOUBLE);
|
||||||
// System.out.println("In:\n" + in);
|
// System.out.println("In:\n" + in);
|
||||||
INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping
|
INDArray permuted = in.permute(0, 2, 1); //Permute, so we get correct order after reshaping
|
||||||
INDArray out = permuted.reshape(4, 2);
|
INDArray out = permuted.reshape(4, 2);
|
||||||
|
@ -4225,7 +4215,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
val s = new long[] {2, 3, 4, 5};
|
val s = new long[] {2, 3, 4, 5};
|
||||||
INDArray arr = Nd4j.rand(s);
|
INDArray arr = Nd4j.rand(s).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
//Test 0,1
|
//Test 0,1
|
||||||
INDArray exp = Nd4j.create(DataType.BOOL, s);
|
INDArray exp = Nd4j.create(DataType.BOOL, s);
|
||||||
|
@ -4328,7 +4318,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testIMax2of4d(Nd4jBackend backend) {
|
public void testIMax2of4d(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
val s = new long[] {2, 3, 4, 5};
|
val s = new long[] {2, 3, 4, 5};
|
||||||
INDArray arr = Nd4j.rand(s);
|
INDArray arr = Nd4j.rand(s).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
//Test 0,1
|
//Test 0,1
|
||||||
INDArray exp = Nd4j.create(DataType.LONG, 4, 5);
|
INDArray exp = Nd4j.create(DataType.LONG, 4, 5);
|
||||||
|
@ -5368,7 +5358,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testAllDistances3(Nd4jBackend backend) {
|
public void testAllDistances3(Nd4jBackend backend) {
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
|
|
||||||
INDArray initialX = Nd4j.rand(5, 10);
|
INDArray initialX = Nd4j.rand(5, 10).castTo(DataType.DOUBLE);
|
||||||
INDArray initialY = initialX.mul(-1);
|
INDArray initialY = initialX.mul(-1);
|
||||||
|
|
||||||
INDArray result = Transforms.allCosineSimilarities(initialX, initialY, 1);
|
INDArray result = Transforms.allCosineSimilarities(initialX, initialY, 1);
|
||||||
|
@ -5421,7 +5411,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEntropy1(Nd4jBackend backend) {
|
public void testEntropy1(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.rand(1, 100);
|
INDArray x = Nd4j.rand(1, 100).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
double exp = MathUtils.entropy(x.data().asDouble());
|
double exp = MathUtils.entropy(x.data().asDouble());
|
||||||
double res = x.entropyNumber().doubleValue();
|
double res = x.entropyNumber().doubleValue();
|
||||||
|
@ -5432,7 +5422,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEntropy2(Nd4jBackend backend) {
|
public void testEntropy2(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.rand(10, 100);
|
INDArray x = Nd4j.rand(10, 100).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray res = x.entropy(1);
|
INDArray res = x.entropy(1);
|
||||||
|
|
||||||
|
@ -5448,7 +5438,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEntropy3(Nd4jBackend backend) {
|
public void testEntropy3(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.rand(1, 100);
|
INDArray x = Nd4j.rand(1, 100).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
double exp = getShannonEntropy(x.data().asDouble());
|
double exp = getShannonEntropy(x.data().asDouble());
|
||||||
double res = x.shannonEntropyNumber().doubleValue();
|
double res = x.shannonEntropyNumber().doubleValue();
|
||||||
|
@ -5459,7 +5449,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testEntropy4(Nd4jBackend backend) {
|
public void testEntropy4(Nd4jBackend backend) {
|
||||||
INDArray x = Nd4j.rand(1, 100);
|
INDArray x = Nd4j.rand(1, 100).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
double exp = getLogEntropy(x.data().asDouble());
|
double exp = getLogEntropy(x.data().asDouble());
|
||||||
double res = x.logEntropyNumber().doubleValue();
|
double res = x.logEntropyNumber().doubleValue();
|
||||||
|
@ -5584,7 +5574,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testNativeSort2(Nd4jBackend backend) {
|
public void testNativeSort2(Nd4jBackend backend) {
|
||||||
INDArray array = Nd4j.rand(1, 10000);
|
INDArray array = Nd4j.rand(1, 10000).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray res = Nd4j.sort(array, true);
|
INDArray res = Nd4j.sort(array, true);
|
||||||
INDArray exp = res.dup();
|
INDArray exp = res.dup();
|
||||||
|
@ -6779,12 +6769,12 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testInconsistentOutput(){
|
public void testInconsistentOutput(){
|
||||||
INDArray in = Nd4j.rand(1, 802816);
|
INDArray in = Nd4j.rand(1, 802816).castTo(DataType.DOUBLE);
|
||||||
INDArray W = Nd4j.rand(802816, 1);
|
INDArray W = Nd4j.rand(802816, 1).castTo(DataType.DOUBLE);
|
||||||
INDArray b = Nd4j.create(1);
|
INDArray b = Nd4j.create(1).castTo(DataType.DOUBLE);
|
||||||
INDArray out = fwd(in, W, b);
|
INDArray out = fwd(in, W, b);
|
||||||
|
|
||||||
for(int i=0;i<100;i++){
|
for(int i = 0;i < 100;i++) {
|
||||||
INDArray out2 = fwd(in, W, b); //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces());
|
INDArray out2 = fwd(in, W, b); //l.activate(inToLayer1, false, LayerWorkspaceMgr.noWorkspaces());
|
||||||
assertEquals( out, out2,"Failed at iteration [" + String.valueOf(i) + "]");
|
assertEquals( out, out2,"Failed at iteration [" + String.valueOf(i) + "]");
|
||||||
}
|
}
|
||||||
|
@ -6964,7 +6954,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
for(boolean biasCorrected : new boolean[]{false, true}) {
|
for(boolean biasCorrected : new boolean[]{false, true}) {
|
||||||
|
|
||||||
INDArray indArray1 = Nd4j.rand(1, 4, 10);
|
INDArray indArray1 = Nd4j.rand(1, 4, 10).castTo(DataType.DOUBLE);
|
||||||
double std = indArray1.stdNumber(biasCorrected).doubleValue();
|
double std = indArray1.stdNumber(biasCorrected).doubleValue();
|
||||||
|
|
||||||
val standardDeviation = new org.apache.commons.math3.stat.descriptive.moment.StandardDeviation(biasCorrected);
|
val standardDeviation = new org.apache.commons.math3.stat.descriptive.moment.StandardDeviation(biasCorrected);
|
||||||
|
@ -7104,7 +7094,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testTearPile_1(Nd4jBackend backend) {
|
public void testTearPile_1(Nd4jBackend backend) {
|
||||||
val source = Nd4j.rand(new int[]{10, 15});
|
val source = Nd4j.rand(new int[]{10, 15}).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
val list = Nd4j.tear(source, 1);
|
val list = Nd4j.tear(source, 1);
|
||||||
|
|
||||||
|
@ -7221,16 +7211,16 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
public void testEmptyShapeRank0(){
|
public void testEmptyShapeRank0(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
int[] s = new int[0];
|
int[] s = new int[0];
|
||||||
INDArray create = Nd4j.create(s);
|
INDArray create = Nd4j.create(s).castTo(DataType.DOUBLE);
|
||||||
INDArray zeros = Nd4j.zeros(s);
|
INDArray zeros = Nd4j.zeros(s).castTo(DataType.DOUBLE);
|
||||||
INDArray ones = Nd4j.ones(s);
|
INDArray ones = Nd4j.ones(s).castTo(DataType.DOUBLE);
|
||||||
INDArray uninit = Nd4j.createUninitialized(s).assign(0);
|
INDArray uninit = Nd4j.createUninitialized(s).assign(0).castTo(DataType.DOUBLE);
|
||||||
INDArray rand = Nd4j.rand(s);
|
INDArray rand = Nd4j.rand(s).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
INDArray tsZero = Nd4j.scalar(0.0);
|
INDArray tsZero = Nd4j.scalar(0.0).castTo(DataType.DOUBLE);
|
||||||
INDArray tsOne = Nd4j.scalar(1.0);
|
INDArray tsOne = Nd4j.scalar(1.0).castTo(DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray tsRand = Nd4j.scalar(Nd4j.rand(new int[]{1,1}).getDouble(0));
|
INDArray tsRand = Nd4j.scalar(Nd4j.rand(new int[]{1,1}).getDouble(0)).castTo(DataType.DOUBLE);
|
||||||
assertEquals(tsZero, create);
|
assertEquals(tsZero, create);
|
||||||
assertEquals(tsZero, zeros);
|
assertEquals(tsZero, zeros);
|
||||||
assertEquals(tsOne, ones);
|
assertEquals(tsOne, ones);
|
||||||
|
@ -7240,11 +7230,11 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
long[] s2 = new long[0];
|
long[] s2 = new long[0];
|
||||||
create = Nd4j.create(s2);
|
create = Nd4j.create(s2).castTo(DataType.DOUBLE);
|
||||||
zeros = Nd4j.zeros(s2);
|
zeros = Nd4j.zeros(s2).castTo(DataType.DOUBLE);
|
||||||
ones = Nd4j.ones(s2);
|
ones = Nd4j.ones(s2).castTo(DataType.DOUBLE);
|
||||||
uninit = Nd4j.createUninitialized(s2).assign(0);
|
uninit = Nd4j.createUninitialized(s2).assign(0).castTo(DataType.DOUBLE);
|
||||||
rand = Nd4j.rand(s2);
|
rand = Nd4j.rand(s2).castTo(DataType.DOUBLE);
|
||||||
|
|
||||||
assertEquals(tsZero, create);
|
assertEquals(tsZero, create);
|
||||||
assertEquals(tsZero, zeros);
|
assertEquals(tsZero, zeros);
|
||||||
|
@ -7709,8 +7699,8 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||||
public void testINDArrayMmulWithTranspose(){
|
public void testINDArrayMmulWithTranspose(){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray a = Nd4j.rand(2,5);
|
INDArray a = Nd4j.rand(2,5).castTo(DataType.DOUBLE);
|
||||||
INDArray b = Nd4j.rand(5,3);
|
INDArray b = Nd4j.rand(5,3).castTo(DataType.DOUBLE);
|
||||||
INDArray exp = a.mmul(b);
|
INDArray exp = a.mmul(b);
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
|
@ -7720,26 +7710,26 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
|
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
|
|
||||||
a = Nd4j.rand(5,2);
|
a = Nd4j.rand(5,2).castTo(DataType.DOUBLE);
|
||||||
b = Nd4j.rand(5,3);
|
b = Nd4j.rand(5,3).castTo(DataType.DOUBLE);
|
||||||
exp = a.transpose().mmul(b);
|
exp = a.transpose().mmul(b);
|
||||||
act = a.mmul(b, MMulTranspose.builder().transposeA(true).build());
|
act = a.mmul(b, MMulTranspose.builder().transposeA(true).build());
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
|
|
||||||
a = Nd4j.rand(2,5);
|
a = Nd4j.rand(2,5).castTo(DataType.DOUBLE);
|
||||||
b = Nd4j.rand(3,5);
|
b = Nd4j.rand(3,5).castTo(DataType.DOUBLE);
|
||||||
exp = a.mmul(b.transpose());
|
exp = a.mmul(b.transpose());
|
||||||
act = a.mmul(b, MMulTranspose.builder().transposeB(true).build());
|
act = a.mmul(b, MMulTranspose.builder().transposeB(true).build());
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
|
|
||||||
a = Nd4j.rand(5,2);
|
a = Nd4j.rand(5,2).castTo(DataType.DOUBLE);
|
||||||
b = Nd4j.rand(3,5);
|
b = Nd4j.rand(3,5).castTo(DataType.DOUBLE);
|
||||||
exp = a.transpose().mmul(b.transpose());
|
exp = a.transpose().mmul(b.transpose());
|
||||||
act = a.mmul(b, MMulTranspose.builder().transposeA(true).transposeB(true).build());
|
act = a.mmul(b, MMulTranspose.builder().transposeA(true).transposeB(true).build());
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
|
|
||||||
a = Nd4j.rand(5,2);
|
a = Nd4j.rand(5,2).castTo(DataType.DOUBLE);
|
||||||
b = Nd4j.rand(3,5);
|
b = Nd4j.rand(3,5).castTo(DataType.DOUBLE);
|
||||||
exp = a.transpose().mmul(b.transpose()).transpose();
|
exp = a.transpose().mmul(b.transpose()).transpose();
|
||||||
act = a.mmul(b, MMulTranspose.builder().transposeA(true).transposeB(true).transposeResult(true).build());
|
act = a.mmul(b, MMulTranspose.builder().transposeA(true).transposeB(true).transposeResult(true).build());
|
||||||
assertEquals(exp, act);
|
assertEquals(exp, act);
|
||||||
|
@ -7778,7 +7768,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Nd4j.rand('z', 1, 1);
|
Nd4j.rand('z', 1, 1).castTo(DataType.DOUBLE);
|
||||||
fail("Expected failure");
|
fail("Expected failure");
|
||||||
} catch (IllegalArgumentException e){
|
} catch (IllegalArgumentException e){
|
||||||
assertTrue(e.getMessage().toLowerCase().contains("order"));
|
assertTrue(e.getMessage().toLowerCase().contains("order"));
|
||||||
|
@ -8300,7 +8290,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends {
|
||||||
@Disabled
|
@Disabled
|
||||||
public void testType1(Nd4jBackend backend) throws IOException {
|
public void testType1(Nd4jBackend backend) throws IOException {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int i = 0; i < 10; ++i) {
|
||||||
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100});
|
INDArray in1 = Nd4j.rand(DataType.DOUBLE, new int[]{100, 100}).castTo(DataType.DOUBLE);
|
||||||
File dir = testDir.toFile();
|
File dir = testDir.toFile();
|
||||||
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir,"test.bin")));
|
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(dir,"test.bin")));
|
||||||
oos.writeObject(in1);
|
oos.writeObject(in1);
|
||||||
|
|
Loading…
Reference in New Issue