Small fixes (#46)

* - fix for eclipse/#7959

* createFromNpy scalar fix

Signed-off-by: raver119 <raver119@gmail.com>

* remove spam

Signed-off-by: raver119 <raver119@gmail.com>

* rng custom ops tests

Signed-off-by: raver119 <raver119@gmail.com>

* Numpy headers validation + tests

Signed-off-by: raver119 <raver119@gmail.com>

* fix for scalar string flat serde

Signed-off-by: raver119 <raver119@gmail.com>

* Where empty shape test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-05 15:15:09 +03:00 committed by AlexDBlack
parent c135883162
commit 15e7984392
14 changed files with 235 additions and 12 deletions

View File

@ -67,6 +67,7 @@ bool verbose = false;
#include <array/ShapeList.h>
#include <array/ConstantDescriptor.h>
#include <helpers/ConstantShapeHelper.h>
#include <array/ConstantDataBuffer.h>
#include <helpers/ConstantHelper.h>
#include <array/TadPack.h>

View File

@ -2725,7 +2725,10 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
Nd4jLong *shapeBuffer;
if (_empty) {
if (shape.size() == 1 && shape[0] == 0) {
// scalar case
shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype);
} else if (_empty) {
if (shapeSize > 0)
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
else
@ -2733,7 +2736,7 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
} else {
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
}
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
return reinterpret_cast<Nd4jPointer>(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
}
void NativeOps::sortByKey(Nd4jPointer *extraPointers,

View File

@ -3276,7 +3276,10 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
Nd4jLong *shapeBuffer;
if (_empty) {
if (shape.size() == 1 && shape[0] == 0) {
// scalar case
shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype);
} else if (_empty) {
if (shapeSize > 0)
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
else
@ -3284,5 +3287,5 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
} else {
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
}
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
return reinterpret_cast<Nd4jPointer>(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
}

View File

@ -393,6 +393,18 @@ cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) {
* @return
*/
cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) {
// check for magic header
if (data == nullptr)
throw std::runtime_error("NULL pointer doesn't look like a NumPy header");
if (data[0] == (char) 0x93) {
std::vector<char> exp({(char) 0x93, 'N', 'U', 'M', 'P', 'Y', (char) 0x01});
std::vector<char> hdr(data, data+7);
if (hdr != exp)
throw std::runtime_error("Pointer doesn't look like a NumPy header");
} else
throw std::runtime_error("Pointer doesn't look like a NumPy header");
//move passed magic
data += 11;
unsigned int *shape;

View File

@ -2077,7 +2077,8 @@ public class Shape {
newStrides[nk] = last_stride;
}
INDArray ret = Nd4j.create(arr.data(), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
// we need to wrap buffer of a current array, to make sure it's properly marked as a View
INDArray ret = Nd4j.create(Nd4j.createBuffer(arr.data(), arr.offset(), arr.length()), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
return ret;
}

View File

@ -6996,14 +6996,14 @@ public class Nd4j {
case UTF8: {
try {
val sb = bb.order(_order);
val pos = bb.position();
val arr = new byte[sb.limit() - sb.position()];
val pos = sb.position();
val arr = new byte[sb.limit() - pos];
for (int e = 0; e < arr.length; e++) {
arr[e] = sb.get(e + sb.position());
arr[e] = sb.get(e + pos);
}
val buffer = new Utf8Buffer(arr, ArrayUtil.prod(shapeOf));
val buffer = new Utf8Buffer(arr, prod);
return Nd4j.create(buffer, shapeOf);
} catch (Exception e) {
throw new RuntimeException(e);

View File

@ -140,7 +140,6 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
dataPointer.capacity(dataBufferElementSize * Shape.length(shapeBuffer));
val jvmShapeInfo = shapeBuffer.asLong();
log.info("JVM shapeInfo: {}", jvmShapeInfo);
val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo);
switch (dtype) {

View File

@ -678,6 +678,7 @@ bool verbose = false;
// #include <array/ShapeList.h>
// #include <array/ConstantDescriptor.h>
// #include <helpers/ConstantShapeHelper.h>
// #include <array/ConstantDataBuffer.h>
// #include <helpers/ConstantHelper.h>
// #include <array/TadPack.h>
@ -2810,6 +2811,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
@Cast("bool") boolean descending);
public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
@ -2835,6 +2874,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
@Cast("Nd4jLong*") long[] tadOffsets,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
IntPointer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
IntBuffer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
int[] dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
IntPointer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
IntBuffer dimension,
int dimensionLength,
@Cast("bool") boolean descending);
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
int[] dimension,
int dimensionLength,
@Cast("bool") boolean descending);
// special sort impl for sorting out COO indices and values
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);

View File

@ -678,6 +678,7 @@ bool verbose = false;
// #include <array/ShapeList.h>
// #include <array/ConstantDescriptor.h>
// #include <helpers/ConstantShapeHelper.h>
// #include <array/ConstantDataBuffer.h>
// #include <helpers/ConstantHelper.h>
// #include <array/TadPack.h>

View File

@ -103,6 +103,24 @@ public class StringArrayTests extends BaseNd4jTest {
assertEquals("gamma", restored.getString(2));
}
@Test
public void testBasicStrings_4a() {
val arrayX = Nd4j.scalar("alpha");
val fb = new FlatBufferBuilder();
val i = arrayX.toFlatArray(fb);
fb.finish(i);
val db = fb.dataBuffer();
val flat = FlatArray.getRootAsFlatArray(db);
val restored = Nd4j.createFromFlatArray(flat);
assertEquals("alpha", arrayX.getString(0));
assertEquals(arrayX, restored);
assertEquals("alpha", restored.getString(0));
}
@Test
public void testBasicStrings_5() {
val arrayX = Nd4j.create("alpha", "beta", "gamma");

View File

@ -31,6 +31,8 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.impl.*;
import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.api.rng.Random;
@ -1424,6 +1426,33 @@ public class RandomTests extends BaseNd4jTest {
return out;
}
@Test
public void testRngRepeatabilityUniform(){
Nd4j.getRandom().setSeed(12345);
INDArray out1 = Nd4j.create(DataType.FLOAT, 10);
Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0));
Nd4j.getRandom().setSeed(12345);
INDArray out2 = Nd4j.create(DataType.FLOAT, 10);
Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0));
assertEquals(out1, out2);
}
@Test
public void testRngRepeatabilityBernoulli(){
Nd4j.getRandom().setSeed(12345);
INDArray out1 = Nd4j.create(DataType.FLOAT, 10);
Nd4j.exec(new RandomBernoulli(Nd4j.createFromArray(10L), out1, 0.5));
Nd4j.getRandom().setSeed(12345);
INDArray out2 = Nd4j.create(DataType.FLOAT, 10);
Nd4j.exec(new RandomBernoulli(Nd4j.createFromArray(10L), out2, 0.5));
assertEquals(out1, out2);
}
@Override
public char ordering() {
return 'c';

View File

@ -256,6 +256,54 @@ public class NumpyFormatTests extends BaseNd4jTest {
}
}
@Test
public void testFromNumpyScalar() throws Exception {
val out = Nd4j.createFromNpyFile(new ClassPathResource("numpy_oneoff/scalar.npy").getFile());
assertEquals(Nd4j.scalar(DataType.INT, 1), out);
}
@Test(expected = RuntimeException.class)
public void readNumpyCorruptHeader1() throws Exception {
File f = testDir.newFolder();
File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile();
byte[] numpyBytes = FileUtils.readFileToByteArray(fValid);
for( int i=0; i<10; i++ ){
numpyBytes[i] = 0;
}
File fCorrupt = new File(f, "corrupt.npy");
FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes);
INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4);
INDArray act1 = Nd4j.createFromNpyFile(fValid);
assertEquals(exp, act1);
INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine
boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content
}
@Test(expected = RuntimeException.class)
public void readNumpyCorruptHeader2() throws Exception {
File f = testDir.newFolder();
File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile();
byte[] numpyBytes = FileUtils.readFileToByteArray(fValid);
for( int i=1; i<10; i++ ){
numpyBytes[i] = 0;
}
File fCorrupt = new File(f, "corrupt.npy");
FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes);
INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4);
INDArray act1 = Nd4j.createFromNpyFile(fValid);
assertEquals(exp, act1);
INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine
boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content
}
@Override
public char ordering() {
return 'c';

View File

@ -289,6 +289,15 @@ public class EmptyTests extends BaseNd4jTest {
assertFalse(Nd4j.create(0).equalShapes(Nd4j.create(1, 0)));
}
@Test
public void testEmptyWhere() {
val mask = Nd4j.createFromArray(false, false, false, false, false);
val result = Nd4j.where(mask, null, null);
assertTrue(result[0].isEmpty());
assertNotNull(result[0].shapeInfoDataBuffer().asLong());
}
@Override
public char ordering() {
return 'c';

View File

@ -29,8 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
/**
* @author Adam Gibson
@ -489,6 +488,17 @@ public class ShapeTestsC extends BaseNd4jTest {
assertEquals(exp, reshaped);
}
@Test
public void testViewAfterReshape() {
val x = Nd4j.rand(3,4);
val x2 = x.ravel();
val x3 = x.reshape(6,2);
assertFalse(x.isView());
assertTrue(x2.isView());
assertTrue(x3.isView());
}
@Override
public char ordering() {
return 'c';