TF import test resources loading precision fixes (#92)
* Fix precision issues when loading from CSV Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
4fb9fa7748
commit
35ab4a72ba
|
@ -55,6 +55,14 @@ import java.util.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TFGraphMapper {
|
public class TFGraphMapper {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use static methods - {@link #importGraph(File)} etc
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public static TFGraphMapper getInstance(){
|
||||||
|
return new TFGraphMapper();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Import a frozen TensorFlow protobuf (.pb) file from the specified file
|
* Import a frozen TensorFlow protobuf (.pb) file from the specified file
|
||||||
*
|
*
|
||||||
|
|
|
@ -594,7 +594,7 @@ public class TFGraphTestAllHelper {
|
||||||
val key = modelDir + "/" + okey;
|
val key = modelDir + "/" + okey;
|
||||||
|
|
||||||
// parse type directly
|
// parse type directly
|
||||||
val value = ArrayOptionsHelper.dataType(split[1]);
|
DataType value = ArrayOptionsHelper.dataType(split[1]);
|
||||||
|
|
||||||
// adding key directly
|
// adding key directly
|
||||||
//if (dtypes.containsKey(key))
|
//if (dtypes.containsKey(key))
|
||||||
|
@ -672,12 +672,35 @@ public class TFGraphTestAllHelper {
|
||||||
INDArray varValue;
|
INDArray varValue;
|
||||||
if(filtered.size() == 0){
|
if(filtered.size() == 0){
|
||||||
//Scalar
|
//Scalar
|
||||||
float[] varContents;
|
String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8);
|
||||||
try(InputStream is = new BufferedInputStream(resources.get(i).getSecond().getInputStream())){
|
switch (type){
|
||||||
varContents = Nd4j.readNumpy(is, ",").data().asFloat();
|
case DOUBLE:
|
||||||
|
case FLOAT:
|
||||||
|
case HALF:
|
||||||
|
case BFLOAT16:
|
||||||
|
varValue = Nd4j.scalar(type, parseDouble(content));
|
||||||
|
break;
|
||||||
|
case LONG:
|
||||||
|
case INT:
|
||||||
|
case SHORT:
|
||||||
|
case UBYTE:
|
||||||
|
case BYTE:
|
||||||
|
case UINT16:
|
||||||
|
case UINT32:
|
||||||
|
case UINT64:
|
||||||
|
varValue = Nd4j.scalar(type, parseLong(content));
|
||||||
|
break;
|
||||||
|
case BOOL:
|
||||||
|
varValue = Nd4j.scalar(parseBoolean(content));
|
||||||
|
break;
|
||||||
|
case UTF8:
|
||||||
|
varValue = Nd4j.scalar(content);
|
||||||
|
break;
|
||||||
|
case COMPRESSED:
|
||||||
|
case UNKNOWN:
|
||||||
|
default:
|
||||||
|
throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
|
||||||
}
|
}
|
||||||
Preconditions.checkState(varContents.length == 1, "Expected length 1 content for scalar shape; got length %s", varContents.length);
|
|
||||||
varValue = Nd4j.scalar(type, varContents[0]);
|
|
||||||
} else {
|
} else {
|
||||||
int[] varShape = new int[filtered.size()];
|
int[] varShape = new int[filtered.size()];
|
||||||
for( int j=0; j<filtered.size(); j++ ){
|
for( int j=0; j<filtered.size(); j++ ){
|
||||||
|
@ -716,18 +739,55 @@ public class TFGraphTestAllHelper {
|
||||||
throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond());
|
throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
content = content.replaceAll("False", "0");
|
if(varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of []
|
||||||
content = content.replaceAll("True", "1");
|
varShape = new int[0];
|
||||||
val varContents = Nd4j.readNumpy(new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8)), ",").data().asDouble();
|
|
||||||
|
|
||||||
if (varShape.length == 1) {
|
String[] cLines = content.split("\n");
|
||||||
if (varShape[0] == 0) {
|
switch (type){
|
||||||
varValue = Nd4j.scalar(type, varContents[0]);
|
case DOUBLE:
|
||||||
} else {
|
case FLOAT:
|
||||||
varValue = Nd4j.create(varContents, new long[]{varContents.length}, type);
|
case HALF:
|
||||||
|
case BFLOAT16:
|
||||||
|
double[] dArr = new double[cLines.length];
|
||||||
|
int x=0;
|
||||||
|
while(x < dArr.length){
|
||||||
|
dArr[x] = parseDouble(cLines[x]);
|
||||||
|
x++;
|
||||||
}
|
}
|
||||||
} else {
|
varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape);
|
||||||
varValue = Nd4j.create(varContents, ArrayUtil.toLongArray(varShape), type);
|
break;
|
||||||
|
case LONG:
|
||||||
|
case INT:
|
||||||
|
case SHORT:
|
||||||
|
case UBYTE:
|
||||||
|
case BYTE:
|
||||||
|
case UINT16:
|
||||||
|
case UINT32:
|
||||||
|
case UINT64:
|
||||||
|
long[] lArr = new long[cLines.length];
|
||||||
|
int y=0;
|
||||||
|
while(y < lArr.length){
|
||||||
|
lArr[y] = parseLong(cLines[y]);
|
||||||
|
y++;
|
||||||
|
}
|
||||||
|
varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape);
|
||||||
|
break;
|
||||||
|
case BOOL:
|
||||||
|
boolean[] bArr = new boolean[cLines.length];
|
||||||
|
int z=0;
|
||||||
|
while(z < bArr.length){
|
||||||
|
bArr[z] = parseBoolean(cLines[z]);
|
||||||
|
z++;
|
||||||
|
}
|
||||||
|
varValue = Nd4j.createFromArray(bArr).reshape('c', varShape);
|
||||||
|
break;
|
||||||
|
case UTF8:
|
||||||
|
varValue = Nd4j.create(cLines).reshape('c', varShape);
|
||||||
|
break;
|
||||||
|
case COMPRESSED:
|
||||||
|
case UNKNOWN:
|
||||||
|
default:
|
||||||
|
throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (NumberFormatException e) {
|
} catch (NumberFormatException e) {
|
||||||
|
@ -741,6 +801,39 @@ public class TFGraphTestAllHelper {
|
||||||
return varMap;
|
return varMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static long parseLong(String line){
|
||||||
|
line = line.trim(); //Handle whitespace
|
||||||
|
if(line.matches("-?\\d+\\.0+")){
|
||||||
|
//Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000"
|
||||||
|
return Long.parseLong(line.substring(0, line.indexOf('.')));
|
||||||
|
} else {
|
||||||
|
return Long.parseLong(line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double parseDouble(String line){
|
||||||
|
line = line.trim(); //Handle whitespace - some lines are like " -inf"
|
||||||
|
if("nan".equalsIgnoreCase(line)){
|
||||||
|
return Double.NaN;
|
||||||
|
} else if("inf".equalsIgnoreCase(line)) {
|
||||||
|
return Double.POSITIVE_INFINITY;
|
||||||
|
} else if("-inf".equalsIgnoreCase(line)){
|
||||||
|
return Double.NEGATIVE_INFINITY;
|
||||||
|
} else {
|
||||||
|
return Double.parseDouble(line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean parseBoolean(String line){
|
||||||
|
line = line.trim();
|
||||||
|
if(line.matches("1(\\.0*)?")){ //Booleans are ocassionally represented like 1.000000 or 0.000000
|
||||||
|
return true;
|
||||||
|
} else if(line.matches("0(\\.0*)?")){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return Boolean.parseBoolean(line);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public static Pair<Double,Double> testPrecisionOverride(String testName){
|
public static Pair<Double,Double> testPrecisionOverride(String testName){
|
||||||
if("conv_4".equalsIgnoreCase(testName)){
|
if("conv_4".equalsIgnoreCase(testName)){
|
||||||
|
|
Loading…
Reference in New Issue