parent
d9cfa8073f
commit
fec620fafa
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
package org.nd4j.tensorflow.conversion;
|
package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.BaseND4JTest;
|
import org.nd4j.BaseND4JTest;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
@ -29,7 +31,9 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.Assert.fail;
|
import static org.junit.Assert.fail;
|
||||||
|
import static org.nd4j.linalg.api.buffer.DataType.*;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class TensorflowConversionTest extends BaseND4JTest {
|
public class TensorflowConversionTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -53,15 +57,39 @@ public class TensorflowConversionTest extends BaseND4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConversionFromNdArray() throws Exception {
|
public void testConversionFromNdArray() throws Exception {
|
||||||
INDArray arr = Nd4j.linspace(1,4,4);
|
DataType[] dtypes = new DataType[]{
|
||||||
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
DOUBLE,
|
||||||
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
FLOAT,
|
||||||
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
SHORT,
|
||||||
assertEquals(arr,fromTensor);
|
LONG,
|
||||||
arr.addi(1.0);
|
BYTE,
|
||||||
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
UBYTE,
|
||||||
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
UINT16,
|
||||||
assertEquals(arr,fromTensor);
|
UINT32,
|
||||||
|
UINT64,
|
||||||
|
BFLOAT16,
|
||||||
|
BOOL,
|
||||||
|
INT,
|
||||||
|
HALF
|
||||||
|
};
|
||||||
|
for(DataType dtype: dtypes){
|
||||||
|
log.debug("Testing conversion for data type " + dtype);
|
||||||
|
INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2).castTo(dtype);
|
||||||
|
TensorflowConversion tensorflowConversion =TensorflowConversion.getInstance();
|
||||||
|
TF_Tensor tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
||||||
|
INDArray fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
||||||
|
assertEquals(arr,fromTensor);
|
||||||
|
if (dtype == BOOL){
|
||||||
|
arr.putScalar(3, 0);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
arr.addi(1.0);
|
||||||
|
}
|
||||||
|
tf_tensor = tensorflowConversion.tensorFromNDArray(arr);
|
||||||
|
fromTensor = tensorflowConversion.ndArrayFromTensor(tf_tensor);
|
||||||
|
assertEquals(arr,fromTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,8 +121,16 @@ public class TensorflowConversion {
|
||||||
default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
|
default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case SHORT: type = DT_INT16; break;
|
||||||
case LONG: type = DT_INT64; break;
|
case LONG: type = DT_INT64; break;
|
||||||
case UTF8: type = DT_STRING; break;
|
case UTF8: type = DT_STRING; break;
|
||||||
|
case BYTE: type = DT_INT8; break;
|
||||||
|
case UBYTE: type = DT_UINT8; break;
|
||||||
|
case UINT16: type = DT_UINT16; break;
|
||||||
|
case UINT32: type = DT_UINT32; break;
|
||||||
|
case UINT64: type = DT_UINT64; break;
|
||||||
|
case BFLOAT16: type = DT_BFLOAT16; break;
|
||||||
|
case BOOL: type = DT_BOOL; break;
|
||||||
default: throw new IllegalArgumentException("Unsupported data type: " + dataType);
|
default: throw new IllegalArgumentException("Unsupported data type: " + dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,6 +258,15 @@ public class TensorflowConversion {
|
||||||
case FLOAT: return FloatIndexer.create(new FloatPointer(pointer));
|
case FLOAT: return FloatIndexer.create(new FloatPointer(pointer));
|
||||||
case INT: return IntIndexer.create(new IntPointer(pointer));
|
case INT: return IntIndexer.create(new IntPointer(pointer));
|
||||||
case LONG: return LongIndexer.create(new LongPointer(pointer));
|
case LONG: return LongIndexer.create(new LongPointer(pointer));
|
||||||
|
case SHORT: return ShortIndexer.create(new ShortPointer(pointer));
|
||||||
|
case BYTE: return ByteIndexer.create(new BytePointer(pointer));
|
||||||
|
case UBYTE: return UByteIndexer.create(new BytePointer(pointer));
|
||||||
|
case UINT16: return UShortIndexer.create(new ShortPointer(pointer));
|
||||||
|
case UINT32: return UIntIndexer.create(new IntPointer(pointer));
|
||||||
|
case UINT64: return ULongIndexer.create(new LongPointer(pointer));
|
||||||
|
case BFLOAT16: return Bfloat16Indexer.create(new ShortPointer(pointer));
|
||||||
|
case HALF: return HalfIndexer.create(new ShortPointer(pointer));
|
||||||
|
case BOOL: return BooleanIndexer.create(new BooleanPointer(pointer));
|
||||||
default: throw new IllegalArgumentException("Illegal type " + type);
|
default: throw new IllegalArgumentException("Illegal type " + type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -258,9 +275,18 @@ public class TensorflowConversion {
|
||||||
switch(tensorflowType) {
|
switch(tensorflowType) {
|
||||||
case DT_DOUBLE: return DataType.DOUBLE;
|
case DT_DOUBLE: return DataType.DOUBLE;
|
||||||
case DT_FLOAT: return DataType.FLOAT;
|
case DT_FLOAT: return DataType.FLOAT;
|
||||||
case DT_INT32: return DataType.LONG;
|
case DT_HALF: return DataType.HALF;
|
||||||
|
case DT_INT16: return DataType.SHORT;
|
||||||
|
case DT_INT32: return DataType.INT;
|
||||||
case DT_INT64: return DataType.LONG;
|
case DT_INT64: return DataType.LONG;
|
||||||
case DT_STRING: return DataType.UTF8;
|
case DT_STRING: return DataType.UTF8;
|
||||||
|
case DT_INT8: return DataType.BYTE;
|
||||||
|
case DT_UINT8: return DataType.UBYTE;
|
||||||
|
case DT_UINT16: return DataType.UINT16;
|
||||||
|
case DT_UINT32: return DataType.UINT32;
|
||||||
|
case DT_UINT64: return DataType.UINT64;
|
||||||
|
case DT_BFLOAT16: return DataType.BFLOAT16;
|
||||||
|
case DT_BOOL: return DataType.BOOL;
|
||||||
default: throw new IllegalArgumentException("Illegal type " + tensorflowType);
|
default: throw new IllegalArgumentException("Illegal type " + tensorflowType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue