diff --git a/libnd4j/include/cnpy/cnpy.cpp b/libnd4j/include/cnpy/cnpy.cpp index 79fe4cdb9..ccc8f7600 100644 --- a/libnd4j/include/cnpy/cnpy.cpp +++ b/libnd4j/include/cnpy/cnpy.cpp @@ -119,6 +119,8 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) { const auto s = data[si]; switch (t) { + case 'b': + return nd4j::DataType::BOOL; case 'i': switch (s) { case '1': return nd4j::DataType::INT8; @@ -128,7 +130,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) { default: throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import"); } - break; case 'f': switch (s) { case '1': return nd4j::DataType::FLOAT8; @@ -138,7 +139,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) { default: throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import"); } - break; case 'u': switch (s) { case '1': return nd4j::DataType::UINT8; @@ -148,14 +148,11 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) { default: throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import"); } - break; case 'c': throw std::runtime_error("Import of complex data types isn't supported yet"); default: throw std::runtime_error("Unknown type marker"); } - - return nd4j::DataType::INHERIT; } template diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java index e7bd7404c..153183f57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java @@ -143,6 +143,20 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo); switch (dtype) { + case BOOL: { + val dPointer = new BooleanPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + BooleanIndexer.create(dPointer)); + } + break; case UBYTE: { val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize); val perfX = PerformanceTracker.getInstance().helperStartTransaction(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index 8b3e98f2a..164760dc0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.io.ClassPathResource; +import org.nd4j.linalg.util.ArrayUtil; import java.io.File; import java.io.FileInputStream; @@ -317,6 +318,14 @@ public class NumpyFormatTests extends BaseNd4jTest { log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue()); } + @Ignore + @Test + public void testNumpyBoolean() { + INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy")); + System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape()))); + System.out.println(out); + } + @Override public char ordering() { return 'c';