TensorflowConversion Data Types (#284)

* dtypes

* bf16 and bool

* tests
master
Fariz Rahman 2020-03-04 04:46:32 +04:00 committed by GitHub
parent d9cfa8073f
commit fec620fafa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 10 deletions

View File

@ -16,9 +16,11 @@
package org.nd4j.tensorflow.conversion;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.junit.Test;
import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
@ -29,7 +31,9 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.nd4j.linalg.api.buffer.DataType.*;
@Slf4j
public class TensorflowConversionTest extends BaseND4JTest {
@Test
@ -53,15 +57,39 @@ public class TensorflowConversionTest extends BaseND4JTest {
@Test
public void testConversionFromNdArray() throws Exception {
INDArray arr = Nd4j.linspace(1,4,4);
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
assertEquals(arr,fromTensor);
arr.addi(1.0);
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
assertEquals(arr,fromTensor);
DataType[] dtypes = new DataType[]{
DOUBLE,
FLOAT,
SHORT,
LONG,
BYTE,
UBYTE,
UINT16,
UINT32,
UINT64,
BFLOAT16,
BOOL,
INT,
HALF
};
for(DataType dtype: dtypes){
log.debug("Testing conversion for data type " + dtype);
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2).castTo(dtype);
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
assertEquals(arr,fromTensor);
if (dtype == BOOL){
arr.putScalar(3, 0);
}
else{
arr.addi(1.0);
}
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
assertEquals(arr,fromTensor);
}
}

View File

@ -121,8 +121,16 @@ public class TensorflowConversion {
default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
}
break;
case SHORT: type = DT_INT16; break;
case LONG: type = DT_INT64; break;
case UTF8: type = DT_STRING; break;
case BYTE: type = DT_INT8; break;
case UBYTE: type = DT_UINT8; break;
case UINT16: type = DT_UINT16; break;
case UINT32: type = DT_UINT32; break;
case UINT64: type = DT_UINT64; break;
case BFLOAT16: type = DT_BFLOAT16; break;
case BOOL: type = DT_BOOL; break;
default: throw new IllegalArgumentException("Unsupported data type: " + dataType);
}
@ -250,6 +258,15 @@ public class TensorflowConversion {
case FLOAT: return FloatIndexer.create(new FloatPointer(pointer));
case INT: return IntIndexer.create(new IntPointer(pointer));
case LONG: return LongIndexer.create(new LongPointer(pointer));
case SHORT: return ShortIndexer.create(new ShortPointer(pointer));
case BYTE: return ByteIndexer.create(new BytePointer(pointer));
case UBYTE: return UByteIndexer.create(new BytePointer(pointer));
case UINT16: return UShortIndexer.create(new ShortPointer(pointer));
case UINT32: return UIntIndexer.create(new IntPointer(pointer));
case UINT64: return ULongIndexer.create(new LongPointer(pointer));
case BFLOAT16: return Bfloat16Indexer.create(new ShortPointer(pointer));
case HALF: return HalfIndexer.create(new ShortPointer(pointer));
case BOOL: return BooleanIndexer.create(new BooleanPointer(pointer));
default: throw new IllegalArgumentException("Illegal type " + type);
}
}
@ -258,9 +275,18 @@ public class TensorflowConversion {
switch(tensorflowType) {
case DT_DOUBLE: return DataType.DOUBLE;
case DT_FLOAT: return DataType.FLOAT;
case DT_INT32: return DataType.LONG;
case DT_HALF: return DataType.HALF;
case DT_INT16: return DataType.SHORT;
case DT_INT32: return DataType.INT;
case DT_INT64: return DataType.LONG;
case DT_STRING: return DataType.UTF8;
case DT_INT8: return DataType.BYTE;
case DT_UINT8: return DataType.UBYTE;
case DT_UINT16: return DataType.UINT16;
case DT_UINT32: return DataType.UINT32;
case DT_UINT64: return DataType.UINT64;
case DT_BFLOAT16: return DataType.BFLOAT16;
case DT_BOOL: return DataType.BOOL;
default: throw new IllegalArgumentException("Illegal type " + tensorflowType);
}
}