Tensorflow import tests and fixes (#435)
* ignored ops checked Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * reconfigured AdjustContrast + commented primitive_gru Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * minor changes + exception ops commented Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * figured out non existent tf ops and random ops check Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * minor changes to tensorflowop and randomness cheks Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * deconv2d tensorfloname removed * Fix Flatbuffers ser/de with character fields Signed-off-by: Alex Black <blacka101@gmail.com> * TFGraphTestAllSameDiff tests passed except NonMaxSuppression Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * minor changes Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * temporary ignored section added Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * ignores removed Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * org.nd4j.base.Preconditions -> org.nd4j.common.base.Preconditions Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * temsorflownames reverts and replace CopyHost * ignored mod op tests due to known issue Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * rsestored mod after fixing in cpp level Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * ignored random_shuffle op test due to known issue Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * increased random_uniform mean/std comparator sensitivity Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * igmored random tests due to SameDiff RNG seed is not set. Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
6e9c849e4a
commit
ec757f654d
|
@ -261,6 +261,10 @@ public abstract class DifferentialFunction {
|
||||||
if(target.getType() == float.class && value instanceof Double){
|
if(target.getType() == float.class && value instanceof Double){
|
||||||
value = ((Double) value).floatValue();
|
value = ((Double) value).floatValue();
|
||||||
}
|
}
|
||||||
|
//Edge case: we store char fields as integers, rather than introduce an extra property
|
||||||
|
if(target.getType() == char.class && value instanceof Integer){
|
||||||
|
value = (char)((Integer)value).intValue();
|
||||||
|
}
|
||||||
|
|
||||||
target.set(this,value);
|
target.set(this,value);
|
||||||
} catch (IllegalAccessException e) {
|
} catch (IllegalAccessException e) {
|
||||||
|
|
|
@ -483,6 +483,8 @@ public class FlatBuffersMapper {
|
||||||
//No op
|
//No op
|
||||||
} else if (v instanceof Boolean) {
|
} else if (v instanceof Boolean) {
|
||||||
b = new boolean[]{(Boolean) v};
|
b = new boolean[]{(Boolean) v};
|
||||||
|
} else if(v instanceof Character){
|
||||||
|
i = new int[]{(Character)v};
|
||||||
} else if (v instanceof Number) {
|
} else if (v instanceof Number) {
|
||||||
if (v instanceof Double) {
|
if (v instanceof Double) {
|
||||||
d = new double[]{(Double) v};
|
d = new double[]{(Double) v};
|
||||||
|
|
|
@ -1220,7 +1220,12 @@ public class OpValidation {
|
||||||
"absargmax",
|
"absargmax",
|
||||||
"absargmin",
|
"absargmin",
|
||||||
"entropy_shannon", //This is a thing, but quite different from our op: https://www.tensorflow.org/versions/r1.2/api_docs/python/tf/contrib/bayesflow/entropy/entropy_shannon
|
"entropy_shannon", //This is a thing, but quite different from our op: https://www.tensorflow.org/versions/r1.2/api_docs/python/tf/contrib/bayesflow/entropy/entropy_shannon
|
||||||
"count_zero"
|
"count_zero",
|
||||||
|
|
||||||
|
"SaveV2",
|
||||||
|
"LoadV2",
|
||||||
|
"RestoreV2",
|
||||||
|
"RandomCrop" // NotImplementedError: Op RandomCrop is not available in GraphDef version 134. It has been removed in version 8. Random crop is now pure Python.
|
||||||
);
|
);
|
||||||
|
|
||||||
return new HashSet<>(list);
|
return new HashSet<>(list);
|
||||||
|
|
|
@ -625,7 +625,6 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.compat.CompatSparseToDense.class,
|
org.nd4j.linalg.api.ops.compat.CompatSparseToDense.class,
|
||||||
org.nd4j.linalg.api.ops.compat.CompatStringSplit.class,
|
org.nd4j.linalg.api.ops.compat.CompatStringSplit.class,
|
||||||
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
|
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
|
||||||
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
|
||||||
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
||||||
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
|
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
|
||||||
org.nd4j.linalg.api.ops.custom.RgbToYiq.class,
|
org.nd4j.linalg.api.ops.custom.RgbToYiq.class,
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
/* ******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2019 Konduit K.K.
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
|
@ -19,14 +18,27 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
public class AdjustContrast extends BaseAdjustContrast {
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public AdjustContrast() {super();}
|
public class AdjustContrast extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AdjustContrast() {
|
||||||
|
super();
|
||||||
|
}
|
||||||
|
|
||||||
public AdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
|
public AdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
|
||||||
super(in, factor, out);
|
Preconditions.checkArgument(in.rank() >= 3,
|
||||||
|
"AdjustContrast: op expects rank of input array to be >= 3, but got %s instead", in.rank());
|
||||||
|
inputArguments.add(in);
|
||||||
|
outputArguments.add(out);
|
||||||
|
|
||||||
|
addTArgument(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
public AdjustContrast(@NonNull INDArray in, double factor) {
|
public AdjustContrast(@NonNull INDArray in, double factor) {
|
||||||
|
@ -44,11 +56,18 @@ public class AdjustContrast extends BaseAdjustContrast {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "adjust_contrast";
|
return "adjust_contrast_v2";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String[] tensorflowNames() {
|
||||||
return "AdjustContrast";
|
return new String[]{"AdjustContrast", "AdjustContrastv2"};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,44 +0,0 @@
|
||||||
/* ******************************************************************************
|
|
||||||
* Copyright (c) 2019 Konduit K.K.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
|
|
||||||
public class AdjustContrastV2 extends BaseAdjustContrast {
|
|
||||||
|
|
||||||
public AdjustContrastV2() {super();}
|
|
||||||
|
|
||||||
public AdjustContrastV2(@NonNull INDArray in, double factor, INDArray out) {
|
|
||||||
super(in, factor, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
public AdjustContrastV2(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
|
||||||
super( sameDiff,new SDVariable[]{in,factor});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "adjust_contrast_v2";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "AdjustContrastv2";
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,52 +0,0 @@
|
||||||
/* ******************************************************************************
|
|
||||||
* Copyright (c) 2019 Konduit K.K.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
|
||||||
public BaseAdjustContrast() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseAdjustContrast(@NonNull INDArray in, double factor, INDArray out) {
|
|
||||||
Preconditions.checkArgument(in.rank() >= 3,
|
|
||||||
"AdjustContrast: op expects rank of input array to be >= 3, but got %s instead", in.rank());
|
|
||||||
inputArguments.add(in);
|
|
||||||
outputArguments.add(out);
|
|
||||||
|
|
||||||
addTArgument(factor);
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseAdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable[] vars) {
|
|
||||||
super("", sameDiff, vars);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
|
||||||
int n = args().length;
|
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
|
||||||
return Collections.singletonList(inputDataTypes.get(0));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -17,10 +17,15 @@ package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class CompareAndBitpack extends DynamicCustomOp {
|
public class CompareAndBitpack extends DynamicCustomOp {
|
||||||
public CompareAndBitpack() {}
|
public CompareAndBitpack() {}
|
||||||
|
|
||||||
|
@ -47,4 +52,11 @@ public class CompareAndBitpack extends DynamicCustomOp {
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "CompareAndBitpack";
|
return "CompareAndBitpack";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||||
|
Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes);
|
||||||
|
Preconditions.checkState(dataTypes.get(0) == dataTypes.get(1), "Input data types must be the same: got %s", dataTypes);
|
||||||
|
return Collections.singletonList(DataType.UINT8);
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -37,8 +37,4 @@ public class RgbToGrayscale extends DynamicCustomOp {
|
||||||
return "rgb_to_grs";
|
return "rgb_to_grs";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RgbToGrs";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,11 +42,6 @@ public class RgbToYiq extends DynamicCustomOp {
|
||||||
return "rgb_to_yiq";
|
return "rgb_to_yiq";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RgbToYiq";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
int n = args().length;
|
int n = args().length;
|
||||||
|
|
|
@ -42,11 +42,6 @@ public class RgbToYuv extends DynamicCustomOp {
|
||||||
return "rgb_to_yuv";
|
return "rgb_to_yuv";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RgbToYuv";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
int n = args().length;
|
int n = args().length;
|
||||||
|
|
|
@ -41,11 +41,6 @@ public class YiqToRgb extends DynamicCustomOp {
|
||||||
return "yiq_to_rgb";
|
return "yiq_to_rgb";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "YiqToRgb";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
int n = args().length;
|
int n = args().length;
|
||||||
|
|
|
@ -42,10 +42,6 @@ public class YuvToRgb extends DynamicCustomOp {
|
||||||
return "yuv_to_rgb";
|
return "yuv_to_rgb";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "YuvToRgb";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
|
|
@ -53,7 +53,7 @@ public class NonMaxSuppressionV3 extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[]{"NonMaxSuppressionV3","NonMaxSuppressionV4"};
|
return new String[]{"NonMaxSuppressionV3","NonMaxSuppressionV4","NonMaxSuppressionV5"};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -306,11 +306,6 @@ public class DeConv2D extends DynamicCustomOp {
|
||||||
return "ConvTranspose";
|
return "ConvTranspose";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Conv2DTranspose";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
|
|
|
@ -62,12 +62,6 @@ public class IsNonDecreasing extends DynamicCustomOp {
|
||||||
return "is_non_decreasing";
|
return "is_non_decreasing";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "IsNonDecreasing";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
||||||
|
|
|
@ -78,7 +78,7 @@ public class CopyOp extends BaseTransformSameOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[]{"Copy","DeepCopy","CopyHost"};
|
return new String[]{"Copy"};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -64,7 +64,7 @@ public class Identity extends BaseDynamicTransformOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[]{"Identity"};
|
return new String[]{"Identity", "DeepCopy", "CopyHost"};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -55,10 +55,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
|
||||||
return "unsorted_segment_mean";
|
return "unsorted_segment_mean";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "UnsortedSegmentMean";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> gradients){
|
public List<SDVariable> doDiff(List<SDVariable> gradients){
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
|
||||||
return "unsorted_segment_sqrt_n";
|
return "unsorted_segment_sqrt_n";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "UnsortedSegmentSqrtN";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> gradients){
|
public List<SDVariable> doDiff(List<SDVariable> gradients){
|
||||||
return new UnsortedSegmentSqrtNBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs();
|
return new UnsortedSegmentSqrtNBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs();
|
||||||
|
|
|
@ -71,9 +71,7 @@ public class RandomGamma extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
if(attributesForNode.containsKey("alpha")) {
|
outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("T").getType());
|
||||||
outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("alpha").getType());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -84,12 +84,6 @@ public class DropOutInverted extends BaseRandomOp {
|
||||||
return "Dropout";
|
return "Dropout";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Dropout";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -100,12 +100,6 @@ public class UniformDistribution extends BaseRandomOp {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "RandomUniformGG";
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
|
|
|
@ -851,7 +851,68 @@ public class TFGraphTestAllHelper {
|
||||||
return (t, s) -> Nd4j.sort(t, true).equals(Nd4j.sort(s, true));
|
return (t, s) -> Nd4j.sort(t, true).equals(Nd4j.sort(s, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
if(modelName.startsWith("alpha_dropout") || modelName.startsWith("layers_dropout") || modelName.equals("dropout"))
|
if(modelName.startsWith("empty")){
|
||||||
|
return (t, s) -> {
|
||||||
|
boolean areEqualShapes = t.equalShapes(s);
|
||||||
|
boolean areEqualDataTypes = t.dataType() == s.dataType();
|
||||||
|
return areEqualShapes && areEqualDataTypes;
|
||||||
|
}; }
|
||||||
|
|
||||||
|
// sum of all elements along dimesions before and after shuffle has to be the same
|
||||||
|
if(modelName.startsWith("random_shuffle")){
|
||||||
|
return (t, s) -> Nd4j.sort(t, true).equals(Nd4j.sort(s, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
if(modelName.startsWith("random_normal")){
|
||||||
|
return (t, s) -> {
|
||||||
|
boolean areEqualShapes = t.equalShapes(s);
|
||||||
|
double meanS = s.meanNumber().doubleValue();
|
||||||
|
double meanT = t.meanNumber().doubleValue();
|
||||||
|
double stdS = s.stdNumber().doubleValue();
|
||||||
|
double stdT = t.stdNumber().doubleValue();
|
||||||
|
double eps = 1;
|
||||||
|
return areEqualShapes && (Math.abs(meanS-meanT) < eps) && (Math.abs(stdS-stdT) < eps);
|
||||||
|
}; }
|
||||||
|
|
||||||
|
if(modelName.startsWith("random_gamma")){
|
||||||
|
return (t, s) -> {
|
||||||
|
boolean areEqualShapes = t.equalShapes(s);
|
||||||
|
boolean nonNegativeValues = (t.minNumber().doubleValue() > 0) && (t.minNumber().doubleValue() > 0);
|
||||||
|
double meanS = s.meanNumber().doubleValue();
|
||||||
|
double meanT = t.meanNumber().doubleValue();
|
||||||
|
double stdS = s.stdNumber().doubleValue();
|
||||||
|
double stdT = t.stdNumber().doubleValue();
|
||||||
|
double eps = 1;
|
||||||
|
return areEqualShapes && nonNegativeValues && (Math.abs(meanS-meanT) < eps) && (Math.abs(stdS-stdT) < eps);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if(modelName.startsWith("random_poisson") || modelName.startsWith("random_poisson_v2")){
|
||||||
|
return (t, s) -> {
|
||||||
|
boolean areEqualShapes = t.equalShapes(s);
|
||||||
|
boolean nonNegativeValues = (t.minNumber().doubleValue() >= 0) && (t.minNumber().doubleValue() >= 0);
|
||||||
|
double meanS = s.meanNumber().doubleValue();
|
||||||
|
double meanT = t.meanNumber().doubleValue();
|
||||||
|
double stdS = s.stdNumber().doubleValue();
|
||||||
|
double stdT = t.stdNumber().doubleValue();
|
||||||
|
double eps = 1;
|
||||||
|
return areEqualShapes && nonNegativeValues && (Math.abs(meanS-meanT) < eps) && (Math.abs(stdS-stdT) < eps);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if(modelName.startsWith("random_uniform")|| modelName.startsWith("random_uniform_int")){
|
||||||
|
return (t, s) -> {
|
||||||
|
boolean areEqualShapes = t.equalShapes(s);
|
||||||
|
double meanS = s.meanNumber().doubleValue();
|
||||||
|
double meanT = t.meanNumber().doubleValue();
|
||||||
|
double stdS = s.stdNumber().doubleValue();
|
||||||
|
double stdT = t.stdNumber().doubleValue();
|
||||||
|
double eps = 1;
|
||||||
|
return areEqualShapes && (Math.abs(stdS-stdT) < eps) && (Math.abs(meanS-meanT) < eps);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if(modelName.startsWith("alpha_dropout") || modelName.startsWith("layers_dropout") || modelName.startsWith("dropout"))
|
||||||
//We can't compare dropout using simple equality due to randomness
|
//We can't compare dropout using simple equality due to randomness
|
||||||
return (t, s) -> {
|
return (t, s) -> {
|
||||||
double[] tfNums = t.ravel().toDoubleVector();
|
double[] tfNums = t.ravel().toDoubleVector();
|
||||||
|
|
|
@ -66,23 +66,29 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
public static final String[] IGNORE_REGEXES = new String[]{
|
public static final String[] IGNORE_REGEXES = new String[]{
|
||||||
//Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958
|
//Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958
|
||||||
//Still failing 2019/09/11
|
//Still failing 2019/09/11
|
||||||
|
//Still failing 2020/04/27
|
||||||
|
//java.lang.IllegalStateException: Requested output variable LogMatrixDeterminant:1 does not exist in SameDiff instance
|
||||||
"slogdet/.*",
|
"slogdet/.*",
|
||||||
|
|
||||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||||
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
|
||||||
"bincount/.*",
|
"bincount/.*",
|
||||||
// Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393
|
// Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393
|
||||||
"is_strictly_increasing/emptyArrayTest/.*",
|
"is_strictly_increasing/emptyArrayTest/.*",
|
||||||
|
|
||||||
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
||||||
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
|
||||||
"truncatemod/.*",
|
"truncatemod/.*",
|
||||||
|
|
||||||
//Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
|
//Still failing as of 2019/09/11 - https://github.com/deeplearning4j/deeplearning4j/issues/6464 - not sure if related to: https://github.com/deeplearning4j/deeplearning4j/issues/6447
|
||||||
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
|
"cnn2d_nn/nhwc_b1_k12_s12_d12_SAME",
|
||||||
|
|
||||||
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
|
//2019/09/11 - No tensorflow op found for SparseTensorDenseAdd
|
||||||
|
// 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: SparseTensorDenseAdd
|
||||||
"confusion/.*",
|
"confusion/.*",
|
||||||
|
|
||||||
//2019/09/11 - Couple of tests failing (InferenceSession issues)
|
//2019/09/11 - Couple of tests failing (InferenceSession issues)
|
||||||
|
// Still failing 2020/04/27 Requested output variable concat does not exist in SameDiff instance
|
||||||
"rnn/bstack/d_.*",
|
"rnn/bstack/d_.*",
|
||||||
|
|
||||||
//2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere
|
//2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere
|
||||||
|
@ -97,85 +103,66 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
"g_11",
|
"g_11",
|
||||||
|
|
||||||
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
||||||
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: Multinomial
|
||||||
"multinomial/.*",
|
"multinomial/.*",
|
||||||
|
|
||||||
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
|
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
|
||||||
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find descriptor for op: deconv3d_tf - class: org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DTF
|
||||||
"conv3d_transpose.*",
|
"conv3d_transpose.*",
|
||||||
|
|
||||||
//2019/11/15 - mapping is not present yet https://github.com/eclipse/deeplearning4j/issues/8397
|
//2019/11/15 - mapping is not present yet https://github.com/eclipse/deeplearning4j/issues/8397
|
||||||
|
// Still failing 2020/04/27 java.lang.AssertionError: Predictions do not match on ragged/reduce_mean/2d_a1, node RaggedReduceMean/truediv
|
||||||
"ragged/reduce_mean/.*",
|
"ragged/reduce_mean/.*",
|
||||||
|
|
||||||
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
|
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
|
||||||
|
// Still failing 2020/04/27 java.lang.IndexOutOfBoundsException: 1
|
||||||
"zeros_like/rank2_float32_dtype_int.*",
|
"zeros_like/rank2_float32_dtype_int.*",
|
||||||
|
|
||||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
|
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
|
||||||
|
// Still failing 2020/04/27 java.lang.AssertionError: Predictions do not match on roll/rank2_float32_zeroshift, node Roll
|
||||||
"roll/.*",
|
"roll/.*",
|
||||||
|
|
||||||
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
|
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
|
||||||
|
// still failing 2020/04/27
|
||||||
|
// java.lang.IllegalStateException: Failed to calculate output shapes for op matrix_band_part (MatrixBandPart) - no shapes were returned by calculateOutputShape()
|
||||||
"matrix_band_part/.*",
|
"matrix_band_part/.*",
|
||||||
|
|
||||||
// 12.20.2019 - https://github.com/eclipse/deeplearning4j/issues/8559
|
// 12.20.2019 - https://github.com/eclipse/deeplearning4j/issues/8559
|
||||||
|
// Still failing 2020/27/04 java.lang.AssertionError: Predictions do not match on fused_batch_norm/float32_nhcw, node FusedBatchNormV3
|
||||||
"fused_batch_norm/.*",
|
"fused_batch_norm/.*",
|
||||||
|
|
||||||
// AB 2020/01/04 - https://github.com/eclipse/deeplearning4j/issues/8592
|
// 01.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8898
|
||||||
"emptyArrayTests/reshape/rank2_shape2-0_2-0--1",
|
"primitive_gru",
|
||||||
|
|
||||||
//AB 2020/01/07 - Known issues
|
// 05.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8921
|
||||||
"bitcast/from_float64_to_int64",
|
"random_poisson/rank1_float16", "random_poisson/rank1_float32", "random_poisson/rank1_float16", "random_poisson/rank1_half",
|
||||||
"bitcast/from_rank2_float64_to_int64",
|
"random_poisson_v2/rank1_float64", "random_poisson_v2/rank1_float16", "random_poisson_v2/rank1_half",
|
||||||
"bitcast/from_float64_to_uint64",
|
|
||||||
|
|
||||||
|
//08.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8927
|
||||||
//NEWLY ADDED TESTCASES from 27/04/2020
|
|
||||||
"non_max_suppression_v2/.*", "non_max_suppression/.*",
|
|
||||||
"random_gamma/.*",
|
"random_gamma/.*",
|
||||||
"non_max_suppression_v5/.*",
|
|
||||||
"non_max_suppression_v4/.*",
|
//08.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8928
|
||||||
"non_max_suppression_v3/.*",
|
|
||||||
"dropout/.*",
|
|
||||||
"max_pool_with_argmax/.*",
|
|
||||||
"conv2d_transpose/.*",
|
|
||||||
"Conv3DBackpropInputV2/.*",
|
"Conv3DBackpropInputV2/.*",
|
||||||
"Conv3DBackpropInput/.*",
|
|
||||||
"mod/.*",
|
//12.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8940
|
||||||
"leaky_relu/.*",
|
"compare_and_bitpack/.*",
|
||||||
"DeepCopy/.*",
|
|
||||||
"empty/.*",
|
//12.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8943
|
||||||
"ones_like/.*",
|
"max_pool_with_argmax/int64_int64_padding_SAME", "max_pool_with_argmax/int32_int64_padding_SAME",
|
||||||
"is_non_decreasing/.*",
|
|
||||||
"div/.*",
|
//12.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8946
|
||||||
"lgamma/.*",
|
"non_max_suppression_v4/.*","non_max_suppression_v5/.*",
|
||||||
|
|
||||||
|
// 18.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8960
|
||||||
|
"random_shuffle/.*",
|
||||||
|
// 18.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8963
|
||||||
"random_uniform/.*",
|
"random_uniform/.*",
|
||||||
"random_uniform_int/.*",
|
"random_uniform_int/.*",
|
||||||
"resize_area/.*",
|
|
||||||
"zeros_like_tf1/.*",
|
|
||||||
"Conv2DTranspose/.*",
|
|
||||||
"rgb_to_yuv/.*",
|
|
||||||
"rgb_to_grayscale/.*",
|
|
||||||
"rgb_to_yiq/.*",
|
|
||||||
"losses/.*",
|
|
||||||
"yiq_to_rgb/.*",
|
|
||||||
"yuv_to_rgb/.*",
|
|
||||||
"emptyArrayTests/.*",
|
|
||||||
"random_normal/.*",
|
"random_normal/.*",
|
||||||
"random_shuffle/.*",
|
"random_gamma/.*",
|
||||||
"random_poisson_v2/.*",
|
|
||||||
"random_poisson/.*",
|
"random_poisson/.*",
|
||||||
"random_crop/.*",
|
"random_poisson/.*",
|
||||||
"compare_and_bitpack/.*",
|
"random_poisson_v2/.*",
|
||||||
"adjust_contrast/.*",
|
|
||||||
"confusion/.*",
|
|
||||||
"bitcast/.*",
|
|
||||||
"roll/.*",
|
|
||||||
"matrix_band_part/.*",
|
|
||||||
"conv3d_transpose_layers/.*",
|
|
||||||
"multinomial/.*",
|
|
||||||
"unsorted_segment/.*",
|
|
||||||
"cnn2d_nn/.*",
|
|
||||||
"truncatemod/.*",
|
|
||||||
"bincount/.*",
|
|
||||||
"slogdet/.*",
|
|
||||||
"adjust_contrast_v2/.*"
|
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -847,22 +847,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertArrayEquals(new long[]{256, 256, 3}, lsd.get(0).getShape());
|
assertArrayEquals(new long[]{256, 256, 3}, lsd.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testAdjustContrastV2() {
|
|
||||||
INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3);
|
|
||||||
INDArray out = Nd4j.createUninitialized(4,4,3);
|
|
||||||
|
|
||||||
INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
|
|
||||||
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
|
|
||||||
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
|
|
||||||
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
|
|
||||||
}).reshape(4,4,3);
|
|
||||||
|
|
||||||
Nd4j.exec(new AdjustContrastV2(in, 2.0, out));
|
|
||||||
|
|
||||||
assertArrayEquals(out.shape(), in.shape());
|
|
||||||
assertEquals(expected, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
|
@Ignore("AS 11/13/2019 https://github.com/eclipse/deeplearning4j/issues/8374")
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue