[WIP] Numpy boolean import (#91)

* numpy bool type

Signed-off-by: raver119 <raver119@gmail.com>

* numpy bool java side

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-30 09:59:22 +03:00 committed by AlexDBlack
parent b95417f7c5
commit 065b34c7cb
3 changed files with 25 additions and 5 deletions

View File

@ -119,6 +119,8 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
const auto s = data[si]; const auto s = data[si];
switch (t) { switch (t) {
case 'b':
return nd4j::DataType::BOOL;
case 'i': case 'i':
switch (s) { switch (s) {
case '1': return nd4j::DataType::INT8; case '1': return nd4j::DataType::INT8;
@ -128,7 +130,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
default: default:
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import"); throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import");
} }
break;
case 'f': case 'f':
switch (s) { switch (s) {
case '1': return nd4j::DataType::FLOAT8; case '1': return nd4j::DataType::FLOAT8;
@ -138,7 +139,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
default: default:
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import"); throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import");
} }
break;
case 'u': case 'u':
switch (s) { switch (s) {
case '1': return nd4j::DataType::UINT8; case '1': return nd4j::DataType::UINT8;
@ -148,14 +148,11 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
default: default:
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import"); throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import");
} }
break;
case 'c': case 'c':
throw std::runtime_error("Import of complex data types isn't supported yet"); throw std::runtime_error("Import of complex data types isn't supported yet");
default: default:
throw std::runtime_error("Unknown type marker"); throw std::runtime_error("Unknown type marker");
} }
return nd4j::DataType::INHERIT;
} }
template <typename T> template <typename T>

View File

@ -143,6 +143,20 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo); val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo);
switch (dtype) { switch (dtype) {
case BOOL: {
val dPointer = new BooleanPointer(dataPointer.limit() / dataBufferElementSize);
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
Pointer.memcpy(dPointer, dataPointer, dataPointer.limit());
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST);
data = Nd4j.createBuffer(dPointer,
dtype,
Shape.length(shapeBuffer),
BooleanIndexer.create(dPointer));
}
break;
case UBYTE: { case UBYTE: {
val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize); val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize);
val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val perfX = PerformanceTracker.getInstance().helperStartTransaction();

View File

@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.util.ArrayUtil;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -317,6 +318,14 @@ public class NumpyFormatTests extends BaseNd4jTest {
log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue()); log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue());
} }
@Ignore
@Test
public void testNumpyBoolean() {
INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy"));
System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape())));
System.out.println(out);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';