diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java index fa005e52e..ecb48f922 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeNearestNeighbor.java @@ -16,10 +16,14 @@ package org.nd4j.linalg.api.ops.impl.image; +import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -53,4 +57,13 @@ public class ResizeNearestNeighbor extends DynamicCustomOp { public List doDiff(List f1) { throw new UnsupportedOperationException(); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), + "Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + if(inputDataTypes.get(0).isFPType()) + return Collections.singletonList(inputDataTypes.get(0)); + return Collections.singletonList(Nd4j.defaultFloatingPointType()); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index cd4955ed5..9668ddd84 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -69,8 +69,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we //2019/07/03 - Out of Memory error "compression_residual_gru", - //2019/07/03 - calculateOutputDataTypes() has not been implemented for org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor - // https://github.com/eclipse/deeplearning4j/issues/7976 + //2019/07/03 - Out of Memory error "deeplabv3_xception_ade20k_train", //2019/07/03 - o.n.i.g.t.TFGraphMapper - No TensorFlow descriptor found for tensor "sample_sequence/model/h0/attn/MatMul", op "BatchMatMulV2"