parent
3fb9aecb59
commit
f91970734b
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -225,8 +226,8 @@ public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_T
|
|||
//TODO work out which!
|
||||
|
||||
SDVariable v;
|
||||
if(shape == null){
|
||||
//No shape -> probably not a variable...
|
||||
if(shape == null || ArrayUtil.contains(shape, 0)){
|
||||
//No shape, or 0 in shape -> probably not a variable...
|
||||
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
|
||||
} else {
|
||||
v = diff.var(entry.getKey(), dt, shape);
|
||||
|
|
|
@ -66,6 +66,26 @@ public class ArrayUtil {
|
|||
return false;
|
||||
}
|
||||
|
||||
public static boolean contains(int[] arr, int value){
|
||||
if(arr == null)
|
||||
return false;
|
||||
for( int i : arr ) {
|
||||
if (i == value)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean contains(long[] arr, int value){
|
||||
if(arr == null)
|
||||
return false;
|
||||
for( long i : arr ) {
|
||||
if (i == value)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param arrs
|
||||
|
|
Loading…
Reference in New Issue