LeakyReLU fix (#55)
* LeakyReLU: Use serScalar to set alpha correctly in TF import LogX: remove incorrect TF mapping Pow: remove TF import method (no mapping) BaseOp: remove duplicate extraArgs Signed-off-by: Ryan Nett <rnett@skymind.io> * un-ignore cifar-10 gan, as it is now passing Signed-off-by: Ryan Nett <rnett@skymind.io>master
parent
027d4d2a47
commit
2d991f5445
|
@ -48,7 +48,7 @@ import java.util.Map;
|
||||||
public abstract class BaseOp extends DifferentialFunction implements Op {
|
public abstract class BaseOp extends DifferentialFunction implements Op {
|
||||||
|
|
||||||
protected INDArray x, y, z;
|
protected INDArray x, y, z;
|
||||||
protected Object[] extraArgs;
|
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
protected String xVertexId,yVertexId,zVertexId;
|
protected String xVertexId,yVertexId,zVertexId;
|
||||||
// cached instance, for dataType checks
|
// cached instance, for dataType checks
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.graph.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
||||||
|
@ -26,6 +27,10 @@ import java.util.Arrays;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Leaky Rectified linear unit. Default alpha=0.01, cutoff=0<br>
|
* Leaky Rectified linear unit. Default alpha=0.01, cutoff=0<br>
|
||||||
|
@ -106,4 +111,12 @@ public class LeakyReLU extends BaseScalarOp {
|
||||||
SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0));
|
SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0));
|
||||||
return Arrays.asList(ret);
|
return Arrays.asList(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode,
|
||||||
|
GraphDef graph) {
|
||||||
|
alpha = attributesForNode.get("alpha").getF();
|
||||||
|
extraArgs = new Object[]{alpha};
|
||||||
|
this.setScalar(Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, alpha));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,6 +81,6 @@ public class LogX extends BaseScalarOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String tensorflowName() {
|
public String tensorflowName() {
|
||||||
return "LogX";
|
throw new NoOpNameFoundException("No TensorFlow op found for " + getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,19 +78,6 @@ public class Pow extends BaseScalarOp {
|
||||||
return "pow";
|
return "pow";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
|
||||||
val weightsName = nodeDef.getInput(1);
|
|
||||||
val tmp = initWith.getArrForVarName(weightsName);
|
|
||||||
|
|
||||||
// if second argument is scalar - we should provide array of same shape
|
|
||||||
if (tmp != null) {
|
|
||||||
if (tmp.isScalar()) {
|
|
||||||
this.pow = tmp.getDouble(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String onnxName() {
|
public String onnxName() {
|
||||||
return "Pow";
|
return "Pow";
|
||||||
|
|
|
@ -59,9 +59,6 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
|
||||||
//2019/07/10 - Libnd4j assign error - https://github.com/eclipse/deeplearning4j/issues/8002
|
//2019/07/10 - Libnd4j assign error - https://github.com/eclipse/deeplearning4j/issues/8002
|
||||||
"xlnet_cased_L-24_H-1024_A-16",
|
"xlnet_cased_L-24_H-1024_A-16",
|
||||||
|
|
||||||
//2019/06/28 - Output incorrect, can't debug b/c https://github.com/eclipse/deeplearning4j/issues/7957
|
|
||||||
"cifar10_gan_85",
|
|
||||||
|
|
||||||
//2019/07/03 - Out of Memory error
|
//2019/07/03 - Out of Memory error
|
||||||
"compression_residual_gru",
|
"compression_residual_gru",
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue