[WIP] Numpy boolean import (#91)

* numpy bool type

Signed-off-by: raver119 <raver119@gmail.com>

* numpy bool java side

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-30 09:59:22 +03:00 committed by AlexDBlack
parent b95417f7c5
commit 065b34c7cb
3 changed files with 25 additions and 5 deletions

View File

@ -119,6 +119,8 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
const auto s = data[si];
switch (t) {
case 'b':
return nd4j::DataType::BOOL;
case 'i':
switch (s) {
case '1': return nd4j::DataType::INT8;
@ -128,7 +130,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
default:
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import");
}
break;
case 'f':
switch (s) {
case '1': return nd4j::DataType::FLOAT8;
@ -138,7 +139,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
default:
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import");
}
break;
case 'u':
switch (s) {
case '1': return nd4j::DataType::UINT8;
@ -148,14 +148,11 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
default:
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import");
}
break;
case 'c':
throw std::runtime_error("Import of complex data types isn't supported yet");
default:
throw std::runtime_error("Unknown type marker");
}
return nd4j::DataType::INHERIT;
}
template <typename T>

View File

@ -143,6 +143,20 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo);
switch (dtype) {
case BOOL: {
val dPointer = new BooleanPointer(dataPointer.limit() / dataBufferElementSize);
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
Pointer.memcpy(dPointer, dataPointer, dataPointer.limit());
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST);
data = Nd4j.createBuffer(dPointer,
dtype,
Shape.length(shapeBuffer),
BooleanIndexer.create(dPointer));
}
break;
case UBYTE: {
val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize);
val perfX = PerformanceTracker.getInstance().helperStartTransaction();

View File

@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.util.ArrayUtil;
import java.io.File;
import java.io.FileInputStream;
@ -317,6 +318,14 @@ public class NumpyFormatTests extends BaseNd4jTest {
log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue());
}
@Ignore
@Test
public void testNumpyBoolean() {
INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy"));
System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape())));
System.out.println(out);
}
@Override
public char ordering() {
return 'c';