diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java index fe252aeeb..95f238973 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java @@ -16,6 +16,7 @@ package org.nd4j.imports.graphmapper; +import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.shade.protobuf.Message; import org.nd4j.shade.protobuf.TextFormat; import lombok.extern.slf4j.Slf4j; @@ -225,8 +226,8 @@ public abstract class BaseGraphMapper probably not a variable... + if(shape == null || ArrayUtil.contains(shape, 0)){ + //No shape, or 0 in shape -> probably not a variable... v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null); } else { v = diff.var(entry.getKey(), dt, shape); diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index 2fe33dfba..cf54d4357 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -66,6 +66,26 @@ public class ArrayUtil { return false; } + public static boolean contains(int[] arr, int value){ + if(arr == null) + return false; + for( int i : arr ) { + if (i == value) + return true; + } + return false; + } + + public static boolean contains(long[] arr, int value){ + if(arr == null) + return false; + for( long i : arr ) { + if (i == value) + return true; + } + return false; + } + /** * * @param arrs