parent
d9cfa8073f
commit
fec620fafa
|
@ -16,9 +16,11 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
|
@ -29,7 +31,9 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
|
|||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.nd4j.linalg.api.buffer.DataType.*;
|
||||
|
||||
@Slf4j
|
||||
public class TensorflowConversionTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
@ -53,15 +57,39 @@ public class TensorflowConversionTest extends BaseND4JTest {
|
|||
|
||||
@Test
|
||||
public void testConversionFromNdArray() throws Exception {
|
||||
INDArray arr = Nd4j.linspace(1,4,4);
|
||||
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
||||
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
||||
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
||||
assertEquals(arr,fromTensor);
|
||||
arr.addi(1.0);
|
||||
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
||||
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
||||
assertEquals(arr,fromTensor);
|
||||
DataType[] dtypes = new DataType[]{
|
||||
DOUBLE,
|
||||
FLOAT,
|
||||
SHORT,
|
||||
LONG,
|
||||
BYTE,
|
||||
UBYTE,
|
||||
UINT16,
|
||||
UINT32,
|
||||
UINT64,
|
||||
BFLOAT16,
|
||||
BOOL,
|
||||
INT,
|
||||
HALF
|
||||
};
|
||||
for(DataType dtype: dtypes){
|
||||
log.debug("Testing conversion for data type " + dtype);
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2).castTo(dtype);
|
||||
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
||||
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
||||
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
||||
assertEquals(arr,fromTensor);
|
||||
if (dtype == BOOL){
|
||||
arr.putScalar(3, 0);
|
||||
}
|
||||
else{
|
||||
arr.addi(1.0);
|
||||
}
|
||||
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
||||
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
||||
assertEquals(arr,fromTensor);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -121,8 +121,16 @@ public class TensorflowConversion {
|
|||
default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
|
||||
}
|
||||
break;
|
||||
case SHORT: type = DT_INT16; break;
|
||||
case LONG: type = DT_INT64; break;
|
||||
case UTF8: type = DT_STRING; break;
|
||||
case BYTE: type = DT_INT8; break;
|
||||
case UBYTE: type = DT_UINT8; break;
|
||||
case UINT16: type = DT_UINT16; break;
|
||||
case UINT32: type = DT_UINT32; break;
|
||||
case UINT64: type = DT_UINT64; break;
|
||||
case BFLOAT16: type = DT_BFLOAT16; break;
|
||||
case BOOL: type = DT_BOOL; break;
|
||||
default: throw new IllegalArgumentException("Unsupported data type: " + dataType);
|
||||
}
|
||||
|
||||
|
@ -250,6 +258,15 @@ public class TensorflowConversion {
|
|||
case FLOAT: return FloatIndexer.create(new FloatPointer(pointer));
|
||||
case INT: return IntIndexer.create(new IntPointer(pointer));
|
||||
case LONG: return LongIndexer.create(new LongPointer(pointer));
|
||||
case SHORT: return ShortIndexer.create(new ShortPointer(pointer));
|
||||
case BYTE: return ByteIndexer.create(new BytePointer(pointer));
|
||||
case UBYTE: return UByteIndexer.create(new BytePointer(pointer));
|
||||
case UINT16: return UShortIndexer.create(new ShortPointer(pointer));
|
||||
case UINT32: return UIntIndexer.create(new IntPointer(pointer));
|
||||
case UINT64: return ULongIndexer.create(new LongPointer(pointer));
|
||||
case BFLOAT16: return Bfloat16Indexer.create(new ShortPointer(pointer));
|
||||
case HALF: return HalfIndexer.create(new ShortPointer(pointer));
|
||||
case BOOL: return BooleanIndexer.create(new BooleanPointer(pointer));
|
||||
default: throw new IllegalArgumentException("Illegal type " + type);
|
||||
}
|
||||
}
|
||||
|
@ -258,9 +275,18 @@ public class TensorflowConversion {
|
|||
switch(tensorflowType) {
|
||||
case DT_DOUBLE: return DataType.DOUBLE;
|
||||
case DT_FLOAT: return DataType.FLOAT;
|
||||
case DT_INT32: return DataType.LONG;
|
||||
case DT_HALF: return DataType.HALF;
|
||||
case DT_INT16: return DataType.SHORT;
|
||||
case DT_INT32: return DataType.INT;
|
||||
case DT_INT64: return DataType.LONG;
|
||||
case DT_STRING: return DataType.UTF8;
|
||||
case DT_INT8: return DataType.BYTE;
|
||||
case DT_UINT8: return DataType.UBYTE;
|
||||
case DT_UINT16: return DataType.UINT16;
|
||||
case DT_UINT32: return DataType.UINT32;
|
||||
case DT_UINT64: return DataType.UINT64;
|
||||
case DT_BFLOAT16: return DataType.BFLOAT16;
|
||||
case DT_BOOL: return DataType.BOOL;
|
||||
default: throw new IllegalArgumentException("Illegal type " + tensorflowType);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue