Another small fix (#251)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-10 13:14:29 +10:00 committed by GitHub
parent 3fb9aecb59
commit f91970734b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 2 deletions

View File

@ -16,6 +16,7 @@
package org.nd4j.imports.graphmapper;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import lombok.extern.slf4j.Slf4j;
@ -225,8 +226,8 @@ public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_T
//TODO work out which!
SDVariable v;
if(shape == null){
//No shape -> probably not a variable...
if(shape == null || ArrayUtil.contains(shape, 0)){
//No shape, or 0 in shape -> probably not a variable...
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
} else {
v = diff.var(entry.getKey(), dt, shape);

View File

@ -66,6 +66,26 @@ public class ArrayUtil {
return false;
}
public static boolean contains(int[] arr, int value){
if(arr == null)
return false;
for( int i : arr ) {
if (i == value)
return true;
}
return false;
}
public static boolean contains(long[] arr, int value){
if(arr == null)
return false;
for( long i : arr ) {
if (i == value)
return true;
}
return false;
}
/**
*
* @param arrs