TF import test resources loading precision fixes (#92)

* Fix precision issues when loading from CSV

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-30 18:58:37 +11:00 committed by GitHub
parent 4fb9fa7748
commit 35ab4a72ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 118 additions and 17 deletions

View File

@ -55,6 +55,14 @@ import java.util.*;
@Slf4j
public class TFGraphMapper {
/**
* @deprecated Use static methods - {@link #importGraph(File)} etc
*/
@Deprecated
public static TFGraphMapper getInstance(){
return new TFGraphMapper();
}
/**
* Import a frozen TensorFlow protobuf (.pb) file from the specified file
*

View File

@ -594,7 +594,7 @@ public class TFGraphTestAllHelper {
val key = modelDir + "/" + okey;
// parse type directly
val value = ArrayOptionsHelper.dataType(split[1]);
DataType value = ArrayOptionsHelper.dataType(split[1]);
// adding key directly
//if (dtypes.containsKey(key))
@ -672,12 +672,35 @@ public class TFGraphTestAllHelper {
INDArray varValue;
if(filtered.size() == 0){
//Scalar
float[] varContents;
try(InputStream is = new BufferedInputStream(resources.get(i).getSecond().getInputStream())){
varContents = Nd4j.readNumpy(is, ",").data().asFloat();
String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8);
switch (type){
case DOUBLE:
case FLOAT:
case HALF:
case BFLOAT16:
varValue = Nd4j.scalar(type, parseDouble(content));
break;
case LONG:
case INT:
case SHORT:
case UBYTE:
case BYTE:
case UINT16:
case UINT32:
case UINT64:
varValue = Nd4j.scalar(type, parseLong(content));
break;
case BOOL:
varValue = Nd4j.scalar(parseBoolean(content));
break;
case UTF8:
varValue = Nd4j.scalar(content);
break;
case COMPRESSED:
case UNKNOWN:
default:
throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
}
Preconditions.checkState(varContents.length == 1, "Expected length 1 content for scalar shape; got length %s", varContents.length);
varValue = Nd4j.scalar(type, varContents[0]);
} else {
int[] varShape = new int[filtered.size()];
for( int j=0; j<filtered.size(); j++ ){
@ -716,18 +739,55 @@ public class TFGraphTestAllHelper {
throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond());
}
} else {
content = content.replaceAll("False", "0");
content = content.replaceAll("True", "1");
val varContents = Nd4j.readNumpy(new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8)), ",").data().asDouble();
if(varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of []
varShape = new int[0];
if (varShape.length == 1) {
if (varShape[0] == 0) {
varValue = Nd4j.scalar(type, varContents[0]);
} else {
varValue = Nd4j.create(varContents, new long[]{varContents.length}, type);
}
} else {
varValue = Nd4j.create(varContents, ArrayUtil.toLongArray(varShape), type);
String[] cLines = content.split("\n");
switch (type){
case DOUBLE:
case FLOAT:
case HALF:
case BFLOAT16:
double[] dArr = new double[cLines.length];
int x=0;
while(x < dArr.length){
dArr[x] = parseDouble(cLines[x]);
x++;
}
varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape);
break;
case LONG:
case INT:
case SHORT:
case UBYTE:
case BYTE:
case UINT16:
case UINT32:
case UINT64:
long[] lArr = new long[cLines.length];
int y=0;
while(y < lArr.length){
lArr[y] = parseLong(cLines[y]);
y++;
}
varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape);
break;
case BOOL:
boolean[] bArr = new boolean[cLines.length];
int z=0;
while(z < bArr.length){
bArr[z] = parseBoolean(cLines[z]);
z++;
}
varValue = Nd4j.createFromArray(bArr).reshape('c', varShape);
break;
case UTF8:
varValue = Nd4j.create(cLines).reshape('c', varShape);
break;
case COMPRESSED:
case UNKNOWN:
default:
throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
}
}
} catch (NumberFormatException e) {
@ -741,6 +801,39 @@ public class TFGraphTestAllHelper {
return varMap;
}
private static long parseLong(String line){
line = line.trim(); //Handle whitespace
if(line.matches("-?\\d+\\.0+")){
//Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000"
return Long.parseLong(line.substring(0, line.indexOf('.')));
} else {
return Long.parseLong(line);
}
}
private static double parseDouble(String line){
line = line.trim(); //Handle whitespace - some lines are like " -inf"
if("nan".equalsIgnoreCase(line)){
return Double.NaN;
} else if("inf".equalsIgnoreCase(line)) {
return Double.POSITIVE_INFINITY;
} else if("-inf".equalsIgnoreCase(line)){
return Double.NEGATIVE_INFINITY;
} else {
return Double.parseDouble(line);
}
}
private static boolean parseBoolean(String line){
line = line.trim();
if(line.matches("1(\\.0*)?")){ //Booleans are ocassionally represented like 1.000000 or 0.000000
return true;
} else if(line.matches("0(\\.0*)?")){
return false;
}
return Boolean.parseBoolean(line);
}
public static Pair<Double,Double> testPrecisionOverride(String testName){
if("conv_4".equalsIgnoreCase(testName)){