Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
9f401dc020
commit
4b23dc01fd
|
@ -16,10 +16,14 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.image;
|
package org.nd4j.linalg.api.ops.impl.image;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
import org.tensorflow.framework.NodeDef;
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
@ -53,4 +57,13 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
|
||||||
|
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||||
|
if(inputDataTypes.get(0).isFPType())
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
return Collections.singletonList(Nd4j.defaultFloatingPointType());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,8 +69,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
|
||||||
//2019/07/03 - Out of Memory error
|
//2019/07/03 - Out of Memory error
|
||||||
"compression_residual_gru",
|
"compression_residual_gru",
|
||||||
|
|
||||||
//2019/07/03 - calculateOutputDataTypes() has not been implemented for org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor
|
//2019/07/03 - Out of Memory error
|
||||||
// https://github.com/eclipse/deeplearning4j/issues/7976
|
|
||||||
"deeplabv3_xception_ade20k_train",
|
"deeplabv3_xception_ade20k_train",
|
||||||
|
|
||||||
//2019/07/03 - o.n.i.g.t.TFGraphMapper - No TensorFlow descriptor found for tensor "sample_sequence/model/h0/attn/MatMul", op "BatchMatMulV2"
|
//2019/07/03 - o.n.i.g.t.TFGraphMapper - No TensorFlow descriptor found for tensor "sample_sequence/model/h0/attn/MatMul", op "BatchMatMulV2"
|
||||||
|
|
Loading…
Reference in New Issue