Small fixes (#46)
* - fix for eclipse/#7959 * createFromNpy scalar fix Signed-off-by: raver119 <raver119@gmail.com> * remove spam Signed-off-by: raver119 <raver119@gmail.com> * rng custom ops tests Signed-off-by: raver119 <raver119@gmail.com> * Numpy headers validation + tests Signed-off-by: raver119 <raver119@gmail.com> * fix for scalar string flat serde Signed-off-by: raver119 <raver119@gmail.com> * Where empty shape test Signed-off-by: raver119 <raver119@gmail.com>master
parent
c135883162
commit
15e7984392
|
@ -67,6 +67,7 @@ bool verbose = false;
|
|||
|
||||
#include <array/ShapeList.h>
|
||||
#include <array/ConstantDescriptor.h>
|
||||
#include <helpers/ConstantShapeHelper.h>
|
||||
#include <array/ConstantDataBuffer.h>
|
||||
#include <helpers/ConstantHelper.h>
|
||||
#include <array/TadPack.h>
|
||||
|
|
|
@ -2725,7 +2725,10 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
|||
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
|
||||
|
||||
Nd4jLong *shapeBuffer;
|
||||
if (_empty) {
|
||||
if (shape.size() == 1 && shape[0] == 0) {
|
||||
// scalar case
|
||||
shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype);
|
||||
} else if (_empty) {
|
||||
if (shapeSize > 0)
|
||||
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
else
|
||||
|
@ -2733,7 +2736,7 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
|||
} else {
|
||||
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
}
|
||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||
return reinterpret_cast<Nd4jPointer>(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
|
||||
}
|
||||
|
||||
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
|
||||
|
|
|
@ -3276,7 +3276,10 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
|||
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
|
||||
|
||||
Nd4jLong *shapeBuffer;
|
||||
if (_empty) {
|
||||
if (shape.size() == 1 && shape[0] == 0) {
|
||||
// scalar case
|
||||
shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype);
|
||||
} else if (_empty) {
|
||||
if (shapeSize > 0)
|
||||
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
else
|
||||
|
@ -3284,5 +3287,5 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
|||
} else {
|
||||
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
}
|
||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||
return reinterpret_cast<Nd4jPointer>(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true));
|
||||
}
|
||||
|
|
|
@ -393,6 +393,18 @@ cnpy::NpyArray cnpy::loadNpyFromPointer(char *data) {
|
|||
* @return
|
||||
*/
|
||||
cnpy::NpyArray cnpy::loadNpyFromHeader(char *data) {
|
||||
// check for magic header
|
||||
if (data == nullptr)
|
||||
throw std::runtime_error("NULL pointer doesn't look like a NumPy header");
|
||||
|
||||
if (data[0] == (char) 0x93) {
|
||||
std::vector<char> exp({(char) 0x93, 'N', 'U', 'M', 'P', 'Y', (char) 0x01});
|
||||
std::vector<char> hdr(data, data+7);
|
||||
if (hdr != exp)
|
||||
throw std::runtime_error("Pointer doesn't look like a NumPy header");
|
||||
} else
|
||||
throw std::runtime_error("Pointer doesn't look like a NumPy header");
|
||||
|
||||
//move passed magic
|
||||
data += 11;
|
||||
unsigned int *shape;
|
||||
|
|
|
@ -2077,7 +2077,8 @@ public class Shape {
|
|||
newStrides[nk] = last_stride;
|
||||
}
|
||||
|
||||
INDArray ret = Nd4j.create(arr.data(), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
|
||||
// we need to wrap buffer of a current array, to make sure it's properly marked as a View
|
||||
INDArray ret = Nd4j.create(Nd4j.createBuffer(arr.data(), arr.offset(), arr.length()), newShape, newStrides, arr.offset(), isFOrder ? 'f' : 'c');
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -6996,14 +6996,14 @@ public class Nd4j {
|
|||
case UTF8: {
|
||||
try {
|
||||
val sb = bb.order(_order);
|
||||
val pos = bb.position();
|
||||
val arr = new byte[sb.limit() - sb.position()];
|
||||
val pos = sb.position();
|
||||
val arr = new byte[sb.limit() - pos];
|
||||
|
||||
for (int e = 0; e < arr.length; e++) {
|
||||
arr[e] = sb.get(e + sb.position());
|
||||
arr[e] = sb.get(e + pos);
|
||||
}
|
||||
|
||||
val buffer = new Utf8Buffer(arr, ArrayUtil.prod(shapeOf));
|
||||
val buffer = new Utf8Buffer(arr, prod);
|
||||
return Nd4j.create(buffer, shapeOf);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
|
|
|
@ -140,7 +140,6 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
|
|||
dataPointer.capacity(dataBufferElementSize * Shape.length(shapeBuffer));
|
||||
|
||||
val jvmShapeInfo = shapeBuffer.asLong();
|
||||
log.info("JVM shapeInfo: {}", jvmShapeInfo);
|
||||
val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo);
|
||||
|
||||
switch (dtype) {
|
||||
|
|
|
@ -678,6 +678,7 @@ bool verbose = false;
|
|||
|
||||
// #include <array/ShapeList.h>
|
||||
// #include <array/ConstantDescriptor.h>
|
||||
// #include <helpers/ConstantShapeHelper.h>
|
||||
// #include <array/ConstantDataBuffer.h>
|
||||
// #include <helpers/ConstantHelper.h>
|
||||
// #include <array/TadPack.h>
|
||||
|
@ -2810,6 +2811,44 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
|||
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
|
||||
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
|
||||
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||
@Cast("bool") boolean descending);
|
||||
|
||||
public native void sortTad(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||
|
@ -2835,6 +2874,56 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
|||
@Cast("Nd4jLong*") long[] tadOffsets,
|
||||
@Cast("bool") boolean descending);
|
||||
|
||||
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||
IntPointer dimension,
|
||||
int dimensionLength,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||
IntBuffer dimension,
|
||||
int dimensionLength,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortTadByKey(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||
int[] dimension,
|
||||
int dimensionLength,
|
||||
@Cast("bool") boolean descending);
|
||||
|
||||
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongPointer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongPointer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongPointer dyShapeInfo,
|
||||
IntPointer dimension,
|
||||
int dimensionLength,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") LongBuffer dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") LongBuffer yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") LongBuffer dyShapeInfo,
|
||||
IntBuffer dimension,
|
||||
int dimensionLength,
|
||||
@Cast("bool") boolean descending);
|
||||
public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
Pointer x, @Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
Pointer dx, @Cast("Nd4jLong*") long[] dxShapeInfo,
|
||||
Pointer y, @Cast("Nd4jLong*") long[] yShapeInfo,
|
||||
Pointer dy, @Cast("Nd4jLong*") long[] dyShapeInfo,
|
||||
int[] dimension,
|
||||
int dimensionLength,
|
||||
@Cast("bool") boolean descending);
|
||||
|
||||
|
||||
// special sort impl for sorting out COO indices and values
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
||||
|
|
|
@ -678,6 +678,7 @@ bool verbose = false;
|
|||
|
||||
// #include <array/ShapeList.h>
|
||||
// #include <array/ConstantDescriptor.h>
|
||||
// #include <helpers/ConstantShapeHelper.h>
|
||||
// #include <array/ConstantDataBuffer.h>
|
||||
// #include <helpers/ConstantHelper.h>
|
||||
// #include <array/TadPack.h>
|
||||
|
|
|
@ -103,6 +103,24 @@ public class StringArrayTests extends BaseNd4jTest {
|
|||
assertEquals("gamma", restored.getString(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicStrings_4a() {
|
||||
val arrayX = Nd4j.scalar("alpha");
|
||||
|
||||
val fb = new FlatBufferBuilder();
|
||||
val i = arrayX.toFlatArray(fb);
|
||||
fb.finish(i);
|
||||
val db = fb.dataBuffer();
|
||||
|
||||
val flat = FlatArray.getRootAsFlatArray(db);
|
||||
val restored = Nd4j.createFromFlatArray(flat);
|
||||
|
||||
assertEquals("alpha", arrayX.getString(0));
|
||||
|
||||
assertEquals(arrayX, restored);
|
||||
assertEquals("alpha", restored.getString(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBasicStrings_5() {
|
||||
val arrayX = Nd4j.create("alpha", "beta", "gamma");
|
||||
|
|
|
@ -31,6 +31,8 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||
import org.nd4j.linalg.api.ops.random.impl.*;
|
||||
import org.nd4j.linalg.api.rng.DefaultRandom;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
@ -1424,6 +1426,33 @@ public class RandomTests extends BaseNd4jTest {
|
|||
return out;
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testRngRepeatabilityUniform(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray out1 = Nd4j.create(DataType.FLOAT, 10);
|
||||
Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0));
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray out2 = Nd4j.create(DataType.FLOAT, 10);
|
||||
Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0));
|
||||
|
||||
assertEquals(out1, out2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRngRepeatabilityBernoulli(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray out1 = Nd4j.create(DataType.FLOAT, 10);
|
||||
Nd4j.exec(new RandomBernoulli(Nd4j.createFromArray(10L), out1, 0.5));
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray out2 = Nd4j.create(DataType.FLOAT, 10);
|
||||
Nd4j.exec(new RandomBernoulli(Nd4j.createFromArray(10L), out2, 0.5));
|
||||
|
||||
assertEquals(out1, out2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -256,6 +256,54 @@ public class NumpyFormatTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFromNumpyScalar() throws Exception {
|
||||
val out = Nd4j.createFromNpyFile(new ClassPathResource("numpy_oneoff/scalar.npy").getFile());
|
||||
assertEquals(Nd4j.scalar(DataType.INT, 1), out);
|
||||
}
|
||||
|
||||
@Test(expected = RuntimeException.class)
|
||||
public void readNumpyCorruptHeader1() throws Exception {
|
||||
File f = testDir.newFolder();
|
||||
|
||||
File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile();
|
||||
byte[] numpyBytes = FileUtils.readFileToByteArray(fValid);
|
||||
for( int i=0; i<10; i++ ){
|
||||
numpyBytes[i] = 0;
|
||||
}
|
||||
File fCorrupt = new File(f, "corrupt.npy");
|
||||
FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes);
|
||||
|
||||
INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4);
|
||||
|
||||
INDArray act1 = Nd4j.createFromNpyFile(fValid);
|
||||
assertEquals(exp, act1);
|
||||
|
||||
INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine
|
||||
boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content
|
||||
}
|
||||
|
||||
@Test(expected = RuntimeException.class)
|
||||
public void readNumpyCorruptHeader2() throws Exception {
|
||||
File f = testDir.newFolder();
|
||||
|
||||
File fValid = new ClassPathResource("numpy_arrays/arange_3,4_float32.npy").getFile();
|
||||
byte[] numpyBytes = FileUtils.readFileToByteArray(fValid);
|
||||
for( int i=1; i<10; i++ ){
|
||||
numpyBytes[i] = 0;
|
||||
}
|
||||
File fCorrupt = new File(f, "corrupt.npy");
|
||||
FileUtils.writeByteArrayToFile(fCorrupt, numpyBytes);
|
||||
|
||||
INDArray exp = Nd4j.arange(12).castTo(DataType.FLOAT).reshape(3,4);
|
||||
|
||||
INDArray act1 = Nd4j.createFromNpyFile(fValid);
|
||||
assertEquals(exp, act1);
|
||||
|
||||
INDArray probablyShouldntLoad = Nd4j.createFromNpyFile(fCorrupt); //Loads fine
|
||||
boolean eq = exp.equals(probablyShouldntLoad); //And is actually equal content
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -289,6 +289,15 @@ public class EmptyTests extends BaseNd4jTest {
|
|||
assertFalse(Nd4j.create(0).equalShapes(Nd4j.create(1, 0)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyWhere() {
|
||||
val mask = Nd4j.createFromArray(false, false, false, false, false);
|
||||
val result = Nd4j.where(mask, null, null);
|
||||
|
||||
assertTrue(result[0].isEmpty());
|
||||
assertNotNull(result[0].shapeInfoDataBuffer().asLong());
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -29,8 +29,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* @author Adam Gibson
|
||||
|
@ -489,6 +488,17 @@ public class ShapeTestsC extends BaseNd4jTest {
|
|||
assertEquals(exp, reshaped);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testViewAfterReshape() {
|
||||
val x = Nd4j.rand(3,4);
|
||||
val x2 = x.ravel();
|
||||
val x3 = x.reshape(6,2);
|
||||
|
||||
assertFalse(x.isView());
|
||||
assertTrue(x2.isView());
|
||||
assertTrue(x3.isView());
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue