LeakyReLU fix (#55)
* LeakyReLU: Use serScalar to set alpha correctly in TF import LogX: remove incorrect TF mapping Pow: remove TF import method (no mapping) BaseOp: remove duplicate extraArgs Signed-off-by: Ryan Nett <rnett@skymind.io> * un-ignore cifar-10 gan, as it is now passing Signed-off-by: Ryan Nett <rnett@skymind.io>
This commit is contained in:
		
							parent
							
								
									027d4d2a47
								
							
						
					
					
						commit
						2d991f5445
					
				| @ -48,7 +48,7 @@ import java.util.Map; | ||||
| public abstract class BaseOp extends DifferentialFunction implements Op { | ||||
| 
 | ||||
|     protected INDArray x, y, z; | ||||
|     protected Object[] extraArgs; | ||||
| 
 | ||||
|     @Getter @Setter | ||||
|     protected String xVertexId,yVertexId,zVertexId; | ||||
|     // cached instance, for dataType checks | ||||
|  | ||||
| @ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.scalar; | ||||
| 
 | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.graph.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.BaseScalarOp; | ||||
| import org.nd4j.linalg.api.ops.BaseTransformOp; | ||||
| @ -26,6 +27,10 @@ import java.util.Arrays; | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.tensorflow.framework.AttrValue; | ||||
| import org.tensorflow.framework.GraphDef; | ||||
| import org.tensorflow.framework.NodeDef; | ||||
| 
 | ||||
| /** | ||||
|  * Leaky Rectified linear unit. Default alpha=0.01, cutoff=0<br> | ||||
| @ -106,4 +111,12 @@ public class LeakyReLU extends BaseScalarOp { | ||||
|         SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0)); | ||||
|         return Arrays.asList(ret); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, | ||||
|             GraphDef graph) { | ||||
|         alpha = attributesForNode.get("alpha").getF(); | ||||
|         extraArgs = new Object[]{alpha}; | ||||
|         this.setScalar(Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, alpha)); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -81,6 +81,6 @@ public class LogX extends BaseScalarOp { | ||||
| 
 | ||||
|     @Override | ||||
|     public String tensorflowName() { | ||||
|         return "LogX"; | ||||
|         throw new NoOpNameFoundException("No TensorFlow op found for " + getClass().getSimpleName()); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -78,19 +78,6 @@ public class Pow extends BaseScalarOp { | ||||
|         return "pow"; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) { | ||||
|         val weightsName = nodeDef.getInput(1); | ||||
|         val tmp = initWith.getArrForVarName(weightsName); | ||||
| 
 | ||||
|         // if second argument is scalar - we should provide array of same shape | ||||
|         if (tmp != null) { | ||||
|             if (tmp.isScalar()) { | ||||
|                 this.pow = tmp.getDouble(0); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String onnxName() { | ||||
|         return "Pow"; | ||||
|  | ||||
| @ -59,9 +59,6 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we | ||||
|             //2019/07/10 - Libnd4j assign error - https://github.com/eclipse/deeplearning4j/issues/8002 | ||||
|             "xlnet_cased_L-24_H-1024_A-16", | ||||
| 
 | ||||
|             //2019/06/28 - Output incorrect, can't debug b/c https://github.com/eclipse/deeplearning4j/issues/7957 | ||||
|             "cifar10_gan_85", | ||||
| 
 | ||||
|             //2019/07/03 - Out of Memory error | ||||
|             "compression_residual_gru", | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user