TF import test resources loading precision fixes (#92)
* Fix precision issues when loading from CSV Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
4fb9fa7748
commit
35ab4a72ba
|
@ -55,6 +55,14 @@ import java.util.*;
|
|||
@Slf4j
|
||||
public class TFGraphMapper {
|
||||
|
||||
/**
|
||||
* @deprecated Use static methods - {@link #importGraph(File)} etc
|
||||
*/
|
||||
@Deprecated
|
||||
public static TFGraphMapper getInstance(){
|
||||
return new TFGraphMapper();
|
||||
}
|
||||
|
||||
/**
|
||||
* Import a frozen TensorFlow protobuf (.pb) file from the specified file
|
||||
*
|
||||
|
|
|
@ -594,7 +594,7 @@ public class TFGraphTestAllHelper {
|
|||
val key = modelDir + "/" + okey;
|
||||
|
||||
// parse type directly
|
||||
val value = ArrayOptionsHelper.dataType(split[1]);
|
||||
DataType value = ArrayOptionsHelper.dataType(split[1]);
|
||||
|
||||
// adding key directly
|
||||
//if (dtypes.containsKey(key))
|
||||
|
@ -672,12 +672,35 @@ public class TFGraphTestAllHelper {
|
|||
INDArray varValue;
|
||||
if(filtered.size() == 0){
|
||||
//Scalar
|
||||
float[] varContents;
|
||||
try(InputStream is = new BufferedInputStream(resources.get(i).getSecond().getInputStream())){
|
||||
varContents = Nd4j.readNumpy(is, ",").data().asFloat();
|
||||
String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8);
|
||||
switch (type){
|
||||
case DOUBLE:
|
||||
case FLOAT:
|
||||
case HALF:
|
||||
case BFLOAT16:
|
||||
varValue = Nd4j.scalar(type, parseDouble(content));
|
||||
break;
|
||||
case LONG:
|
||||
case INT:
|
||||
case SHORT:
|
||||
case UBYTE:
|
||||
case BYTE:
|
||||
case UINT16:
|
||||
case UINT32:
|
||||
case UINT64:
|
||||
varValue = Nd4j.scalar(type, parseLong(content));
|
||||
break;
|
||||
case BOOL:
|
||||
varValue = Nd4j.scalar(parseBoolean(content));
|
||||
break;
|
||||
case UTF8:
|
||||
varValue = Nd4j.scalar(content);
|
||||
break;
|
||||
case COMPRESSED:
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
|
||||
}
|
||||
Preconditions.checkState(varContents.length == 1, "Expected length 1 content for scalar shape; got length %s", varContents.length);
|
||||
varValue = Nd4j.scalar(type, varContents[0]);
|
||||
} else {
|
||||
int[] varShape = new int[filtered.size()];
|
||||
for( int j=0; j<filtered.size(); j++ ){
|
||||
|
@ -716,18 +739,55 @@ public class TFGraphTestAllHelper {
|
|||
throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond());
|
||||
}
|
||||
} else {
|
||||
content = content.replaceAll("False", "0");
|
||||
content = content.replaceAll("True", "1");
|
||||
val varContents = Nd4j.readNumpy(new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8)), ",").data().asDouble();
|
||||
if(varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of []
|
||||
varShape = new int[0];
|
||||
|
||||
if (varShape.length == 1) {
|
||||
if (varShape[0] == 0) {
|
||||
varValue = Nd4j.scalar(type, varContents[0]);
|
||||
} else {
|
||||
varValue = Nd4j.create(varContents, new long[]{varContents.length}, type);
|
||||
}
|
||||
} else {
|
||||
varValue = Nd4j.create(varContents, ArrayUtil.toLongArray(varShape), type);
|
||||
String[] cLines = content.split("\n");
|
||||
switch (type){
|
||||
case DOUBLE:
|
||||
case FLOAT:
|
||||
case HALF:
|
||||
case BFLOAT16:
|
||||
double[] dArr = new double[cLines.length];
|
||||
int x=0;
|
||||
while(x < dArr.length){
|
||||
dArr[x] = parseDouble(cLines[x]);
|
||||
x++;
|
||||
}
|
||||
varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape);
|
||||
break;
|
||||
case LONG:
|
||||
case INT:
|
||||
case SHORT:
|
||||
case UBYTE:
|
||||
case BYTE:
|
||||
case UINT16:
|
||||
case UINT32:
|
||||
case UINT64:
|
||||
long[] lArr = new long[cLines.length];
|
||||
int y=0;
|
||||
while(y < lArr.length){
|
||||
lArr[y] = parseLong(cLines[y]);
|
||||
y++;
|
||||
}
|
||||
varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape);
|
||||
break;
|
||||
case BOOL:
|
||||
boolean[] bArr = new boolean[cLines.length];
|
||||
int z=0;
|
||||
while(z < bArr.length){
|
||||
bArr[z] = parseBoolean(cLines[z]);
|
||||
z++;
|
||||
}
|
||||
varValue = Nd4j.createFromArray(bArr).reshape('c', varShape);
|
||||
break;
|
||||
case UTF8:
|
||||
varValue = Nd4j.create(cLines).reshape('c', varShape);
|
||||
break;
|
||||
case COMPRESSED:
|
||||
case UNKNOWN:
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
|
||||
}
|
||||
}
|
||||
} catch (NumberFormatException e) {
|
||||
|
@ -741,6 +801,39 @@ public class TFGraphTestAllHelper {
|
|||
return varMap;
|
||||
}
|
||||
|
||||
private static long parseLong(String line){
|
||||
line = line.trim(); //Handle whitespace
|
||||
if(line.matches("-?\\d+\\.0+")){
|
||||
//Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000"
|
||||
return Long.parseLong(line.substring(0, line.indexOf('.')));
|
||||
} else {
|
||||
return Long.parseLong(line);
|
||||
}
|
||||
}
|
||||
|
||||
private static double parseDouble(String line){
|
||||
line = line.trim(); //Handle whitespace - some lines are like " -inf"
|
||||
if("nan".equalsIgnoreCase(line)){
|
||||
return Double.NaN;
|
||||
} else if("inf".equalsIgnoreCase(line)) {
|
||||
return Double.POSITIVE_INFINITY;
|
||||
} else if("-inf".equalsIgnoreCase(line)){
|
||||
return Double.NEGATIVE_INFINITY;
|
||||
} else {
|
||||
return Double.parseDouble(line);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean parseBoolean(String line){
|
||||
line = line.trim();
|
||||
if(line.matches("1(\\.0*)?")){ //Booleans are ocassionally represented like 1.000000 or 0.000000
|
||||
return true;
|
||||
} else if(line.matches("0(\\.0*)?")){
|
||||
return false;
|
||||
}
|
||||
return Boolean.parseBoolean(line);
|
||||
}
|
||||
|
||||
|
||||
public static Pair<Double,Double> testPrecisionOverride(String testName){
|
||||
if("conv_4".equalsIgnoreCase(testName)){
|
||||
|
|
Loading…
Reference in New Issue