Another small fix (#251)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-10 13:14:29 +10:00 committed by GitHub
parent 3fb9aecb59
commit f91970734b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 2 deletions

View File

@ -16,6 +16,7 @@
package org.nd4j.imports.graphmapper; package org.nd4j.imports.graphmapper;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.protobuf.Message; import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat; import org.nd4j.shade.protobuf.TextFormat;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -225,8 +226,8 @@ public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_T
//TODO work out which! //TODO work out which!
SDVariable v; SDVariable v;
if(shape == null){ if(shape == null || ArrayUtil.contains(shape, 0)){
//No shape -> probably not a variable... //No shape, or 0 in shape -> probably not a variable...
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null); v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
} else { } else {
v = diff.var(entry.getKey(), dt, shape); v = diff.var(entry.getKey(), dt, shape);

View File

@ -66,6 +66,26 @@ public class ArrayUtil {
return false; return false;
} }
public static boolean contains(int[] arr, int value){
if(arr == null)
return false;
for( int i : arr ) {
if (i == value)
return true;
}
return false;
}
public static boolean contains(long[] arr, int value){
if(arr == null)
return false;
for( long i : arr ) {
if (i == value)
return true;
}
return false;
}
/** /**
* *
* @param arrs * @param arrs