parent
3fb9aecb59
commit
f91970734b
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.imports.graphmapper;
|
package org.nd4j.imports.graphmapper;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.nd4j.shade.protobuf.Message;
|
import org.nd4j.shade.protobuf.Message;
|
||||||
import org.nd4j.shade.protobuf.TextFormat;
|
import org.nd4j.shade.protobuf.TextFormat;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -225,8 +226,8 @@ public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_T
|
||||||
//TODO work out which!
|
//TODO work out which!
|
||||||
|
|
||||||
SDVariable v;
|
SDVariable v;
|
||||||
if(shape == null){
|
if(shape == null || ArrayUtil.contains(shape, 0)){
|
||||||
//No shape -> probably not a variable...
|
//No shape, or 0 in shape -> probably not a variable...
|
||||||
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
|
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
|
||||||
} else {
|
} else {
|
||||||
v = diff.var(entry.getKey(), dt, shape);
|
v = diff.var(entry.getKey(), dt, shape);
|
||||||
|
|
|
@ -66,6 +66,26 @@ public class ArrayUtil {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static boolean contains(int[] arr, int value){
|
||||||
|
if(arr == null)
|
||||||
|
return false;
|
||||||
|
for( int i : arr ) {
|
||||||
|
if (i == value)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean contains(long[] arr, int value){
|
||||||
|
if(arr == null)
|
||||||
|
return false;
|
||||||
|
for( long i : arr ) {
|
||||||
|
if (i == value)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param arrs
|
* @param arrs
|
||||||
|
|
Loading…
Reference in New Issue