From 15e7984392d3969dd0c080ec79eecae8d8cd2131 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 5 Jul 2019 15:15:09 +0300 Subject: [PATCH] Small fixes (#46) * - fix for eclipse/#7959 * createFromNpy scalar fix Signed-off-by: raver119 * remove spam Signed-off-by: raver119 * rng custom ops tests Signed-off-by: raver119 * Numpy headers validation + tests Signed-off-by: raver119 * fix for scalar string flat serde Signed-off-by: raver119 * Where empty shape test Signed-off-by: raver119 --- libnd4j/blas/NativeOps.h | 1 + libnd4j/blas/cpu/NativeOps.cpp | 7 +- libnd4j/blas/cuda/NativeOps.cu | 7 +- libnd4j/include/cnpy/cnpy.cpp | 12 +++ .../java/org/nd4j/linalg/api/shape/Shape.java | 3 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 8 +- .../nativeblas/BaseNativeNDArrayFactory.java | 1 - .../java/org/nd4j/nativeblas/Nd4jCuda.java | 89 +++++++++++++++++++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 1 + .../nd4j/linalg/mixed/StringArrayTests.java | 18 ++++ .../java/org/nd4j/linalg/rng/RandomTests.java | 29 ++++++ .../nd4j/linalg/serde/NumpyFormatTests.java | 48 ++++++++++ .../org/nd4j/linalg/shape/EmptyTests.java | 9 ++ .../org/nd4j/linalg/shape/ShapeTestsC.java | 14 ++- 14 files changed, 235 insertions(+), 12 deletions(-) diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index d2b124885..1c818d528 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -67,6 +67,7 @@ bool verbose = false; #include #include +#include #include #include #include diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 6a92e0825..d281bdfac 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2725,7 +2725,10 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) { auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); Nd4jLong *shapeBuffer; - if (_empty) { + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { if (shapeSize > 0) shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); else @@ -2733,7 +2736,7 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) { } else { shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); } - return reinterpret_cast(shapeBuffer); + return reinterpret_cast(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); } void NativeOps::sortByKey(Nd4jPointer *extraPointers, diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 87fa93223..4fa3a36fa 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3276,7 +3276,10 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) { auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); Nd4jLong *shapeBuffer; - if (_empty) { + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { if (shapeSize > 0) shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); else @@ -3284,5 +3287,5 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) { } else { shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); } - return reinterpret_cast(shapeBuffer); + return reinterpret_cast(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); } diff --git a/libnd4j/include/cnpy/cnpy.cpp b/libnd4j/include/cnpy/cnpy.cpp index 0d19dbe82..79fe4cdb9 100644 --- a/libnd4j/include/cnpy/cnpy.cpp +++ b/libnd4j/include/cnpy/cnpy.cpp @@ -393,6 +393,18 @@ cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) { * @return */ cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) { + // check for magic header + if (data == nullptr) + throw std::runtime_error("NULL pointer doesn't look like a NumPy header"); + + if (data[0] == (char) 0x93) { + std::vector exp({(char) 0x93, 'N', 'U', 'M', 'P', 'Y', (char) 0x01}); + std::vector hdr(data, data+7); + if (hdr != exp) + throw std::runtime_error("Pointer doesn't look like a NumPy header"); + } else + throw std::runtime_error("Pointer doesn't look like a NumPy header"); + //move passed magic data += 11; unsigned int *shape; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index e561b2257..89e8cbec4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -2077,7 +2077,8 @@ public class Shape { newStrides[nk] = last_stride; } - INDArray ret = Nd4j.create(arr.data(), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c'); + // we need to wrap buffer of a current array, to make sure it's properly marked as a View + INDArray ret = Nd4j.create(Nd4j.createBuffer(arr.data(), arr.offset(), arr.length()), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c'); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index c441a10f8..1ac932584 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -6996,14 +6996,14 @@ public class Nd4j { case UTF8: { try { val sb = bb.order(_order); - val pos = bb.position(); - val arr = new byte[sb.limit() - sb.position()]; + val pos = sb.position(); + val arr = new byte[sb.limit() - pos]; for (int e = 0; e < arr.length; e++) { - arr[e] = sb.get(e + sb.position()); + arr[e] = sb.get(e + pos); } - val buffer = new Utf8Buffer(arr, ArrayUtil.prod(shapeOf)); + val buffer = new Utf8Buffer(arr, prod); return Nd4j.create(buffer, shapeOf); } catch (Exception e) { throw new RuntimeException(e); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java index 440b096c3..3b7ac83c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java @@ -140,7 +140,6 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { dataPointer.capacity(dataBufferElementSize * Shape.length(shapeBuffer)); val jvmShapeInfo = shapeBuffer.asLong(); - log.info("JVM shapeInfo: {}", jvmShapeInfo); val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo); switch (dtype) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 911b33414..3f7794074 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -678,6 +678,7 @@ bool verbose = false; // #include // #include +// #include // #include // #include // #include @@ -2810,6 +2811,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps { Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, @Cast("bool") boolean descending); + public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + @Cast("bool") boolean descending); + + public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + @Cast("bool") boolean descending); + public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers, Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, @@ -2835,6 +2874,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps { @Cast("Nd4jLong*") long[] tadOffsets, @Cast("bool") boolean descending); + public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + IntPointer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + IntBuffer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + int[] dimension, + int dimensionLength, + @Cast("bool") boolean descending); + + public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo, + IntPointer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo, + IntBuffer dimension, + int dimensionLength, + @Cast("bool") boolean descending); + public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers, + Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo, + Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo, + Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo, + Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo, + int[] dimension, + int dimensionLength, + @Cast("bool") boolean descending); + // special sort impl for sorting out COO indices and values public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index a41a1e3c7..cc3d2f284 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -678,6 +678,7 @@ bool verbose = false; // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java index 1d5ae1077..80030512a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/StringArrayTests.java @@ -103,6 +103,24 @@ public class StringArrayTests extends BaseNd4jTest { assertEquals("gamma", restored.getString(2)); } + @Test + public void testBasicStrings_4a() { + val arrayX = Nd4j.scalar("alpha"); + + val fb = new FlatBufferBuilder(); + val i = arrayX.toFlatArray(fb); + fb.finish(i); + val db = fb.dataBuffer(); + + val flat = FlatArray.getRootAsFlatArray(db); + val restored = Nd4j.createFromFlatArray(flat); + + assertEquals("alpha", arrayX.getString(0)); + + assertEquals(arrayX, restored); + assertEquals("alpha", restored.getString(0)); + } + @Test public void testBasicStrings_5() { val arrayX = Nd4j.create("alpha", "beta", "gamma"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index 98829eddd..66979b95d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -31,6 +31,8 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; +import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; +import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.impl.*; import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.Random; @@ -1424,6 +1426,33 @@ public class RandomTests extends BaseNd4jTest { return out; } + + @Test + public void testRngRepeatabilityUniform(){ + Nd4j.getRandom().setSeed(12345); + INDArray out1 = Nd4j.create(DataType.FLOAT, 10); + Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0)); + + Nd4j.getRandom().setSeed(12345); + INDArray out2 = Nd4j.create(DataType.FLOAT, 10); + Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0)); + + assertEquals(out1, out2); + } + + @Test + public void testRngRepeatabilityBernoulli(){ + Nd4j.getRandom().setSeed(12345); + INDArray out1 = Nd4j.create(DataType.FLOAT, 10); + Nd4j.exec(new RandomBernoulli(Nd4j.createFromArray(10L), out1, 0.5)); + + Nd4j.getRandom().setSeed(12345); + INDArray out2 = Nd4j.create(DataType.FLOAT, 10); + Nd4j.exec(new RandomBernoulli(Nd4j.createFromArray(10L), out2, 0.5)); + + assertEquals(out1, out2); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index caec61321..06a0b62fb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -256,6 +256,54 @@ public class NumpyFormatTests extends BaseNd4jTest { } } + @Test + public void testFromNumpyScalar() throws Exception { + val out = Nd4j.createFromNpyFile(new ClassPathResource("numpy_oneoff/scalar.npy").getFile()); + assertEquals(Nd4j.scalar(DataType.INT, 1), out); + } + + @Test(expected = RuntimeException.class) + public void readNumpyCorruptHeader1() throws Exception { + File f = testDir.newFolder(); + + File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); + byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); + for( int i=0; i<10; i++ ){ + numpyBytes[i] = 0; + } + File fCorrupt = new File(f, "corrupt.npy"); + FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); + + INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); + + INDArray act1 = Nd4j.createFromNpyFile(fValid); + assertEquals(exp, act1); + + INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine + boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content + } + + @Test(expected = RuntimeException.class) + public void readNumpyCorruptHeader2() throws Exception { + File f = testDir.newFolder(); + + File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile(); + byte[] numpyBytes = FileUtils.readFileToByteArray(fValid); + for( int i=1; i<10; i++ ){ + numpyBytes[i] = 0; + } + File fCorrupt = new File(f, "corrupt.npy"); + FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes); + + INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4); + + INDArray act1 = Nd4j.createFromNpyFile(fValid); + assertEquals(exp, act1); + + INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine + boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index ce0b94b81..d492d8612 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -289,6 +289,15 @@ public class EmptyTests extends BaseNd4jTest { assertFalse(Nd4j.create(0).equalShapes(Nd4j.create(1, 0))); } + @Test + public void testEmptyWhere() { + val mask = Nd4j.createFromArray(false, false, false, false, false); + val result = Nd4j.where(mask, null, null); + + assertTrue(result[0].isEmpty()); + assertNotNull(result[0].shapeInfoDataBuffer().asLong()); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java index 06b593df9..22133976b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeTestsC.java @@ -29,8 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; /** * @author Adam Gibson @@ -489,6 +488,17 @@ public class ShapeTestsC extends BaseNd4jTest { assertEquals(exp, reshaped); } + @Test + public void testViewAfterReshape() { + val x = Nd4j.rand(3,4); + val x2 = x.ravel(); + val x3 = x.reshape(6,2); + + assertFalse(x.isView()); + assertTrue(x2.isView()); + assertTrue(x3.isView()); + } + @Override public char ordering() { return 'c';