[WIP] Numpy boolean import (#91)
* numpy bool type Signed-off-by: raver119 <raver119@gmail.com> * numpy bool java side Signed-off-by: raver119 <raver119@gmail.com>master
parent
b95417f7c5
commit
065b34c7cb
|
@ -119,6 +119,8 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
|
|||
const auto s = data[si];
|
||||
|
||||
switch (t) {
|
||||
case 'b':
|
||||
return nd4j::DataType::BOOL;
|
||||
case 'i':
|
||||
switch (s) {
|
||||
case '1': return nd4j::DataType::INT8;
|
||||
|
@ -128,7 +130,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
|
|||
default:
|
||||
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Integer data types import");
|
||||
}
|
||||
break;
|
||||
case 'f':
|
||||
switch (s) {
|
||||
case '1': return nd4j::DataType::FLOAT8;
|
||||
|
@ -138,7 +139,6 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
|
|||
default:
|
||||
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Float data types import");
|
||||
}
|
||||
break;
|
||||
case 'u':
|
||||
switch (s) {
|
||||
case '1': return nd4j::DataType::UINT8;
|
||||
|
@ -148,14 +148,11 @@ nd4j::DataType cnpy::dataTypeFromHeader(char *data) {
|
|||
default:
|
||||
throw std::runtime_error("Only data sizes of [1, 2, 4, 8] are supported for Unsigned data types import");
|
||||
}
|
||||
break;
|
||||
case 'c':
|
||||
throw std::runtime_error("Import of complex data types isn't supported yet");
|
||||
default:
|
||||
throw std::runtime_error("Unknown type marker");
|
||||
}
|
||||
|
||||
return nd4j::DataType::INHERIT;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -143,6 +143,20 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
|
|||
val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo);
|
||||
|
||||
switch (dtype) {
|
||||
case BOOL: {
|
||||
val dPointer = new BooleanPointer(dataPointer.limit() / dataBufferElementSize);
|
||||
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
|
||||
|
||||
Pointer.memcpy(dPointer, dataPointer, dataPointer.limit());
|
||||
|
||||
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST);
|
||||
|
||||
data = Nd4j.createBuffer(dPointer,
|
||||
dtype,
|
||||
Shape.length(shapeBuffer),
|
||||
BooleanIndexer.create(dPointer));
|
||||
}
|
||||
break;
|
||||
case UBYTE: {
|
||||
val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize);
|
||||
val perfX = PerformanceTracker.getInstance().helperStartTransaction();
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
|
@ -317,6 +318,14 @@ public class NumpyFormatTests extends BaseNd4jTest {
|
|||
log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue());
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testNumpyBoolean() {
|
||||
INDArray out = Nd4j.createFromNpyFile(new File("c:/Users/raver/Downloads/error2.npy"));
|
||||
System.out.println(ArrayUtil.toList(ArrayUtil.toInts(out.shape())));
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue