TF import test resources loading precision fixes (#92)
* Fix precision issues when loading from CSV Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									4fb9fa7748
								
							
						
					
					
						commit
						35ab4a72ba
					
				@ -55,6 +55,14 @@ import java.util.*;
 | 
				
			|||||||
@Slf4j
 | 
					@Slf4j
 | 
				
			||||||
public class TFGraphMapper {
 | 
					public class TFGraphMapper {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * @deprecated Use static methods - {@link #importGraph(File)} etc
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    @Deprecated
 | 
				
			||||||
 | 
					    public static TFGraphMapper getInstance(){
 | 
				
			||||||
 | 
					        return new TFGraphMapper();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Import a frozen TensorFlow protobuf (.pb) file from the specified file
 | 
					     * Import a frozen TensorFlow protobuf (.pb) file from the specified file
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
 | 
				
			|||||||
@ -594,7 +594,7 @@ public class TFGraphTestAllHelper {
 | 
				
			|||||||
                                        val key = modelDir + "/" + okey;
 | 
					                                        val key = modelDir + "/" + okey;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                        // parse type directly
 | 
					                                        // parse type directly
 | 
				
			||||||
                                        val value = ArrayOptionsHelper.dataType(split[1]);
 | 
					                                        DataType value = ArrayOptionsHelper.dataType(split[1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                                        // adding key directly
 | 
					                                        // adding key directly
 | 
				
			||||||
                                        //if (dtypes.containsKey(key))
 | 
					                                        //if (dtypes.containsKey(key))
 | 
				
			||||||
@ -672,12 +672,35 @@ public class TFGraphTestAllHelper {
 | 
				
			|||||||
            INDArray varValue;
 | 
					            INDArray varValue;
 | 
				
			||||||
            if(filtered.size() == 0){
 | 
					            if(filtered.size() == 0){
 | 
				
			||||||
                //Scalar
 | 
					                //Scalar
 | 
				
			||||||
                float[] varContents;
 | 
					                String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8);
 | 
				
			||||||
                try(InputStream is = new BufferedInputStream(resources.get(i).getSecond().getInputStream())){
 | 
					                switch (type){
 | 
				
			||||||
                    varContents = Nd4j.readNumpy(is, ",").data().asFloat();
 | 
					                    case DOUBLE:
 | 
				
			||||||
 | 
					                    case FLOAT:
 | 
				
			||||||
 | 
					                    case HALF:
 | 
				
			||||||
 | 
					                    case BFLOAT16:
 | 
				
			||||||
 | 
					                        varValue = Nd4j.scalar(type, parseDouble(content));
 | 
				
			||||||
 | 
					                        break;
 | 
				
			||||||
 | 
					                    case LONG:
 | 
				
			||||||
 | 
					                    case INT:
 | 
				
			||||||
 | 
					                    case SHORT:
 | 
				
			||||||
 | 
					                    case UBYTE:
 | 
				
			||||||
 | 
					                    case BYTE:
 | 
				
			||||||
 | 
					                    case UINT16:
 | 
				
			||||||
 | 
					                    case UINT32:
 | 
				
			||||||
 | 
					                    case UINT64:
 | 
				
			||||||
 | 
					                        varValue = Nd4j.scalar(type, parseLong(content));
 | 
				
			||||||
 | 
					                        break;
 | 
				
			||||||
 | 
					                    case BOOL:
 | 
				
			||||||
 | 
					                        varValue = Nd4j.scalar(parseBoolean(content));
 | 
				
			||||||
 | 
					                        break;
 | 
				
			||||||
 | 
					                    case UTF8:
 | 
				
			||||||
 | 
					                        varValue = Nd4j.scalar(content);
 | 
				
			||||||
 | 
					                        break;
 | 
				
			||||||
 | 
					                    case COMPRESSED:
 | 
				
			||||||
 | 
					                    case UNKNOWN:
 | 
				
			||||||
 | 
					                    default:
 | 
				
			||||||
 | 
					                        throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                Preconditions.checkState(varContents.length == 1, "Expected length 1 content for scalar shape; got length %s", varContents.length);
 | 
					 | 
				
			||||||
                varValue = Nd4j.scalar(type, varContents[0]);
 | 
					 | 
				
			||||||
            } else {
 | 
					            } else {
 | 
				
			||||||
                int[] varShape = new int[filtered.size()];
 | 
					                int[] varShape = new int[filtered.size()];
 | 
				
			||||||
                for( int j=0; j<filtered.size(); j++ ){
 | 
					                for( int j=0; j<filtered.size(); j++ ){
 | 
				
			||||||
@ -716,18 +739,55 @@ public class TFGraphTestAllHelper {
 | 
				
			|||||||
                            throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond());
 | 
					                            throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond());
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
                    } else {
 | 
					                    } else {
 | 
				
			||||||
                        content = content.replaceAll("False", "0");
 | 
					                        if(varShape.length == 1 && varShape[0] == 0)        //Annoyingly, some scalars have shape [0] instead of []
 | 
				
			||||||
                        content = content.replaceAll("True", "1");
 | 
					                            varShape = new int[0];
 | 
				
			||||||
                        val varContents = Nd4j.readNumpy(new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8)), ",").data().asDouble();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        if (varShape.length == 1) {
 | 
					                        String[] cLines = content.split("\n");
 | 
				
			||||||
                            if (varShape[0] == 0) {
 | 
					                        switch (type){
 | 
				
			||||||
                                varValue = Nd4j.scalar(type, varContents[0]);
 | 
					                            case DOUBLE:
 | 
				
			||||||
                            } else {
 | 
					                            case FLOAT:
 | 
				
			||||||
                                varValue = Nd4j.create(varContents, new long[]{varContents.length}, type);
 | 
					                            case HALF:
 | 
				
			||||||
 | 
					                            case BFLOAT16:
 | 
				
			||||||
 | 
					                                double[] dArr = new double[cLines.length];
 | 
				
			||||||
 | 
					                                int x=0;
 | 
				
			||||||
 | 
					                                while(x < dArr.length){
 | 
				
			||||||
 | 
					                                    dArr[x] = parseDouble(cLines[x]);
 | 
				
			||||||
 | 
					                                    x++;
 | 
				
			||||||
                                }
 | 
					                                }
 | 
				
			||||||
                        } else {
 | 
					                                varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape);
 | 
				
			||||||
                            varValue = Nd4j.create(varContents, ArrayUtil.toLongArray(varShape), type);
 | 
					                                break;
 | 
				
			||||||
 | 
					                            case LONG:
 | 
				
			||||||
 | 
					                            case INT:
 | 
				
			||||||
 | 
					                            case SHORT:
 | 
				
			||||||
 | 
					                            case UBYTE:
 | 
				
			||||||
 | 
					                            case BYTE:
 | 
				
			||||||
 | 
					                            case UINT16:
 | 
				
			||||||
 | 
					                            case UINT32:
 | 
				
			||||||
 | 
					                            case UINT64:
 | 
				
			||||||
 | 
					                                long[] lArr = new long[cLines.length];
 | 
				
			||||||
 | 
					                                int y=0;
 | 
				
			||||||
 | 
					                                while(y < lArr.length){
 | 
				
			||||||
 | 
					                                    lArr[y] = parseLong(cLines[y]);
 | 
				
			||||||
 | 
					                                    y++;
 | 
				
			||||||
 | 
					                                }
 | 
				
			||||||
 | 
					                                varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape);
 | 
				
			||||||
 | 
					                                break;
 | 
				
			||||||
 | 
					                            case BOOL:
 | 
				
			||||||
 | 
					                                boolean[] bArr = new boolean[cLines.length];
 | 
				
			||||||
 | 
					                                int z=0;
 | 
				
			||||||
 | 
					                                while(z < bArr.length){
 | 
				
			||||||
 | 
					                                    bArr[z] = parseBoolean(cLines[z]);
 | 
				
			||||||
 | 
					                                    z++;
 | 
				
			||||||
 | 
					                                }
 | 
				
			||||||
 | 
					                                varValue = Nd4j.createFromArray(bArr).reshape('c', varShape);
 | 
				
			||||||
 | 
					                                break;
 | 
				
			||||||
 | 
					                            case UTF8:
 | 
				
			||||||
 | 
					                                varValue = Nd4j.create(cLines).reshape('c', varShape);
 | 
				
			||||||
 | 
					                                break;
 | 
				
			||||||
 | 
					                            case COMPRESSED:
 | 
				
			||||||
 | 
					                            case UNKNOWN:
 | 
				
			||||||
 | 
					                            default:
 | 
				
			||||||
 | 
					                                throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type);
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } catch (NumberFormatException e) {
 | 
					                } catch (NumberFormatException e) {
 | 
				
			||||||
@ -741,6 +801,39 @@ public class TFGraphTestAllHelper {
 | 
				
			|||||||
        return varMap;
 | 
					        return varMap;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private static long parseLong(String line){
 | 
				
			||||||
 | 
					        line = line.trim();       //Handle whitespace
 | 
				
			||||||
 | 
					        if(line.matches("-?\\d+\\.0+")){
 | 
				
			||||||
 | 
					            //Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000"
 | 
				
			||||||
 | 
					            return Long.parseLong(line.substring(0, line.indexOf('.')));
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            return Long.parseLong(line);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private static double parseDouble(String line){
 | 
				
			||||||
 | 
					        line = line.trim();   //Handle whitespace - some lines are like "      -inf"
 | 
				
			||||||
 | 
					        if("nan".equalsIgnoreCase(line)){
 | 
				
			||||||
 | 
					            return Double.NaN;
 | 
				
			||||||
 | 
					        } else if("inf".equalsIgnoreCase(line)) {
 | 
				
			||||||
 | 
					            return Double.POSITIVE_INFINITY;
 | 
				
			||||||
 | 
					        } else if("-inf".equalsIgnoreCase(line)){
 | 
				
			||||||
 | 
					            return Double.NEGATIVE_INFINITY;
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            return Double.parseDouble(line);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private static boolean parseBoolean(String line){
 | 
				
			||||||
 | 
					        line = line.trim();
 | 
				
			||||||
 | 
					        if(line.matches("1(\\.0*)?")){          //Booleans are ocassionally represented like 1.000000 or 0.000000
 | 
				
			||||||
 | 
					            return true;
 | 
				
			||||||
 | 
					        } else if(line.matches("0(\\.0*)?")){
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return Boolean.parseBoolean(line);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static Pair<Double,Double> testPrecisionOverride(String testName){
 | 
					    public static Pair<Double,Double> testPrecisionOverride(String testName){
 | 
				
			||||||
        if("conv_4".equalsIgnoreCase(testName)){
 | 
					        if("conv_4".equalsIgnoreCase(testName)){
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user