fix for #7976, update test comment (to OOM) (#45)

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-07-05 19:09:47 -07:00 committed by AlexDBlack
parent 9f401dc020
commit 4b23dc01fd
2 changed files with 14 additions and 2 deletions

View File

@ -16,10 +16,14 @@
package org.nd4j.linalg.api.ops.impl.image; package org.nd4j.linalg.api.ops.impl.image;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.NodeDef;
@ -53,4 +57,13 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
if(inputDataTypes.get(0).isFPType())
return Collections.singletonList(inputDataTypes.get(0));
return Collections.singletonList(Nd4j.defaultFloatingPointType());
}
} }

View File

@ -69,8 +69,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
//2019/07/03 - Out of Memory error //2019/07/03 - Out of Memory error
"compression_residual_gru", "compression_residual_gru",
//2019/07/03 - calculateOutputDataTypes() has not been implemented for org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor //2019/07/03 - Out of Memory error
// https://github.com/eclipse/deeplearning4j/issues/7976
"deeplabv3_xception_ade20k_train", "deeplabv3_xception_ade20k_train",
//2019/07/03 - o.n.i.g.t.TFGraphMapper - No TensorFlow descriptor found for tensor "sample_sequence/model/h0/attn/MatMul", op "BatchMatMulV2" //2019/07/03 - o.n.i.g.t.TFGraphMapper - No TensorFlow descriptor found for tensor "sample_sequence/model/h0/attn/MatMul", op "BatchMatMulV2"