SameDiff If, While, and Misc changes (#52)
* softmax and logSoftmax w/ dimension Signed-off-by: Ryan Nett <rnett@skymind.io> * start of while Signed-off-by: Ryan Nett <rnett@skymind.io> * if, start of javadocs Signed-off-by: Ryan Nett <rnett@skymind.io> * while foreward pass working, backprop WIP Signed-off-by: Ryan Nett <rnett@skymind.io> * no backprop Signed-off-by: Ryan Nett <rnett@skymind.io> * Tensorflow style if/while (& tests), name scope fixes (and test), argument interceptor (for if/while), use '_' in op names instead of ':' Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * many fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * many fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * Some fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * cleanup if condition doesn't return boolean Signed-off-by: Ryan Nett <rnett@skymind.io> * serialization fix Signed-off-by: Ryan Nett <rnett@skymind.io> * use constants instead of magic numbers Signed-off-by: Ryan Nett <rnett@skymind.io>
This commit is contained in:
		
							parent
							
								
									2d991f5445
								
							
						
					
					
						commit
						daf3950d8d
					
				@ -451,6 +451,17 @@ public abstract class DifferentialFunction {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public void replaceArg(int i, SDVariable newArg){
 | 
				
			||||||
 | 
					        if(sameDiff != null){
 | 
				
			||||||
 | 
					            sameDiff.replaceArgFor(i, newArg, this);
 | 
				
			||||||
 | 
					            if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){
 | 
				
			||||||
 | 
					                sameDiff.removePropertyToResolve(this, args()[i].getVarName());
 | 
				
			||||||
 | 
					            } else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){
 | 
				
			||||||
 | 
					                sameDiff.addPropertyToResolve(this, newArg.getVarName());
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Return the output variables for this differential function.
 | 
					     * Return the output variables for this differential function.
 | 
				
			||||||
@ -652,9 +663,9 @@ public abstract class DifferentialFunction {
 | 
				
			|||||||
                    scope = "";
 | 
					                    scope = "";
 | 
				
			||||||
                else
 | 
					                else
 | 
				
			||||||
                    scope = scope + "/";
 | 
					                    scope = scope + "/";
 | 
				
			||||||
                String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
 | 
					                String varName = scope + sameDiff.generateNewVarName(opName(),argIndex).replace(":", "_");
 | 
				
			||||||
                while(sameDiff.functionExists(varName)) {
 | 
					                while(sameDiff.functionExists(varName)) {
 | 
				
			||||||
                    varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
 | 
					                    varName = scope + sameDiff.generateNewVarName(opName(), argIndex).replace(":", "_");
 | 
				
			||||||
                    argIndex++;
 | 
					                    argIndex++;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -16,6 +16,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.autodiff.functions;
 | 
					package org.nd4j.autodiff.functions;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.lang.reflect.Method;
 | 
				
			||||||
 | 
					import java.util.Arrays;
 | 
				
			||||||
 | 
					import java.util.HashMap;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
import lombok.NonNull;
 | 
					import lombok.NonNull;
 | 
				
			||||||
import lombok.val;
 | 
					import lombok.val;
 | 
				
			||||||
@ -30,36 +35,183 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 | 
				
			|||||||
import org.nd4j.linalg.api.ops.NoOp;
 | 
					import org.nd4j.linalg.api.ops.NoOp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
 | 
					import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
 | 
					import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
 | 
					import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.*;
 | 
					import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.loss.*;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.loss.bp.*;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.*;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.HingeLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.HuberLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.L2Loss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.LogLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.Moments;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.*;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.AMax;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.AMax;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.AMin;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.AMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.Max;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.Max;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.Min;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.Min;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.*;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.Prod;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.reduce3.*;
 | 
					import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.LogX;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.scalar.Pow;
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.Pow;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.scalar.*;
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.scalar.comparison.*;
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.scatter.*;
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.shape.*;
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.Step;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Broadcast;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Concat;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Cross;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Diag;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.ExpandDims;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Gather;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.GatherNd;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.MergeAvg;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.MeshGrid;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.OneHot;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.ParallelStack;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Permute;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Rank;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Repeat;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Reshape;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Size;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Slice;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Stack;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.StridedSlice;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Tile;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Transpose;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.Unstack;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.shape.ZerosLike;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
 | 
					import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
 | 
					import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
 | 
					import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
 | 
				
			||||||
@ -77,37 +229,165 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm;
 | 
				
			|||||||
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.same.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Abs;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.segment.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Floor;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Round;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.same.Square;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
 | 
					import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
 | 
					import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
 | 
					import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.random.custom.RandomNormal;
 | 
					import org.nd4j.linalg.api.ops.random.custom.RandomNormal;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.random.impl.*;
 | 
					import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.Range;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.Shape;
 | 
					import org.nd4j.linalg.api.shape.Shape;
 | 
				
			||||||
import org.nd4j.linalg.indexing.conditions.Condition;
 | 
					import org.nd4j.linalg.indexing.conditions.Condition;
 | 
				
			||||||
import org.nd4j.linalg.util.ArrayUtil;
 | 
					import org.nd4j.linalg.util.ArrayUtil;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.lang.reflect.Method;
 | 
					 | 
				
			||||||
import java.util.Arrays;
 | 
					 | 
				
			||||||
import java.util.HashMap;
 | 
					 | 
				
			||||||
import java.util.List;
 | 
					 | 
				
			||||||
import java.util.Map;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
@ -1611,11 +1891,24 @@ public class DifferentialFunctionFactory {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public SDVariable logSoftmax(SDVariable i_v, int dimension) {
 | 
				
			||||||
 | 
					        validateDifferentialFunctionsameDiff(i_v);
 | 
				
			||||||
 | 
					        return new LogSoftMax(sameDiff(), i_v, dimension).outputVariable();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) {
 | 
					    public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) {
 | 
				
			||||||
        validateDifferentialFunctionsameDiff(arg);
 | 
					        validateDifferentialFunctionsameDiff(arg);
 | 
				
			||||||
        return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable();
 | 
					        return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt, int dimension) {
 | 
				
			||||||
 | 
					        validateDifferentialFunctionsameDiff(arg);
 | 
				
			||||||
 | 
					        return new LogSoftMaxDerivative(sameDiff(), arg, wrt, dimension).outputVariable();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... dimension) {
 | 
					    public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... dimension) {
 | 
				
			||||||
        return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable();
 | 
					        return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -2296,6 +2589,22 @@ public class DifferentialFunctionFactory {
 | 
				
			|||||||
        return tile(func, ArrayUtil.toInts(input.getShape()));
 | 
					        return tile(func, ArrayUtil.toInts(input.getShape()));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public SDVariable enter(SDVariable x, String frameName){
 | 
				
			||||||
 | 
					        return new Enter(sameDiff, frameName, x).outputVariable();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public SDVariable enter(SDVariable x, String frameName, boolean isConstant){
 | 
				
			||||||
 | 
					        return new Enter(sameDiff, frameName, x, isConstant).outputVariable();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public SDVariable exit(SDVariable x){
 | 
				
			||||||
 | 
					        return new Exit(sameDiff, x).outputVariable();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public SDVariable nextIteration(SDVariable x){
 | 
				
			||||||
 | 
					        return new NextIteration(sameDiff, x).outputVariable();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public String toString() {
 | 
					    public String toString() {
 | 
				
			||||||
        return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
 | 
					        return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					/*******************************************************************************
 | 
				
			||||||
 | 
					 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 * under the License.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package org.nd4j.autodiff.samediff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Internal interface used to apply a transform to any arguments used within a certain block
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Intended for internal use only.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Managed with {@link SameDiff#addArgumentInterceptor(ArgumentInterceptor)}, {@link SameDiff#removeArgumentInterceptor()},
 | 
				
			||||||
 | 
					 * {@link SameDiff#pauseArgumentInterceptor()}, and {@link SameDiff#unpauseArgumentInterceptor()}
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					public interface ArgumentInterceptor {
 | 
				
			||||||
 | 
					    SDVariable intercept(SDVariable argument);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -16,6 +16,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.autodiff.samediff;
 | 
					package org.nd4j.autodiff.samediff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.Objects;
 | 
				
			||||||
import lombok.*;
 | 
					import lombok.*;
 | 
				
			||||||
import lombok.extern.slf4j.Slf4j;
 | 
					import lombok.extern.slf4j.Slf4j;
 | 
				
			||||||
import onnx.OnnxProto3;
 | 
					import onnx.OnnxProto3;
 | 
				
			||||||
@ -91,7 +92,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
 | 
				
			|||||||
        Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
 | 
					        Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String nameScope = sameDiff.currentNameScope();
 | 
					        String nameScope = sameDiff.currentNameScope();
 | 
				
			||||||
        if(nameScope != null){
 | 
					        if(nameScope != null && !varName.startsWith(nameScope + "/")){
 | 
				
			||||||
            varName = nameScope + "/" + varName;
 | 
					            varName = nameScope + "/" + varName;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1785,26 +1786,6 @@ public class SDVariable extends DifferentialFunction implements Serializable {
 | 
				
			|||||||
                (variableType == VariableType.PLACEHOLDER && shape != null ? ",shape=" + Arrays.toString(shape): "") + ")";
 | 
					                (variableType == VariableType.PLACEHOLDER && shape != null ? ",shape=" + Arrays.toString(shape): "") + ")";
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					 | 
				
			||||||
    public boolean equals(Object o) {
 | 
					 | 
				
			||||||
        if (this == o) return true;
 | 
					 | 
				
			||||||
        if (o == null || getClass() != o.getClass()) return false;
 | 
					 | 
				
			||||||
        if (!super.equals(o)) return false;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable that = (SDVariable) o;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (varName != null ? !varName.equals(that.varName) : that.varName != null) return false;
 | 
					 | 
				
			||||||
        return weightInitScheme != null ? weightInitScheme.equals(that.weightInitScheme) : that.weightInitScheme == null;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @Override
 | 
					 | 
				
			||||||
    public int hashCode() {
 | 
					 | 
				
			||||||
        int result = super.hashCode();
 | 
					 | 
				
			||||||
        result = 31 * result + (varName != null ? varName.hashCode() : 0);
 | 
					 | 
				
			||||||
        result = 31 * result + (weightInitScheme != null ? weightInitScheme.hashCode() : 0);
 | 
					 | 
				
			||||||
        return result;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String onnxName() {
 | 
					    public String onnxName() {
 | 
				
			||||||
        throw new NoOpNameFoundException("No onnx op opName found for " +  opName());
 | 
					        throw new NoOpNameFoundException("No onnx op opName found for " +  opName());
 | 
				
			||||||
@ -1965,5 +1946,36 @@ public class SDVariable extends DifferentialFunction implements Serializable {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
        return x;
 | 
					        return x;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
   
 | 
					
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public boolean equals(Object o) {
 | 
				
			||||||
 | 
					        if (this == o) {
 | 
				
			||||||
 | 
					            return true;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (!(o instanceof SDVariable)) {
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable that = (SDVariable) o;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (!Objects.equals(varName, that.varName)) {
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (variableType != that.variableType) {
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if(sameDiff != that.sameDiff){
 | 
				
			||||||
 | 
					            return false;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return dataType == that.dataType;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public int hashCode() {
 | 
				
			||||||
 | 
					        int result = super.hashCode();
 | 
				
			||||||
 | 
					        result = 31 * result + (varName != null ? varName.hashCode() : 0);
 | 
				
			||||||
 | 
					        result = 31 * result + (variableType != null ? variableType.hashCode() : 0);
 | 
				
			||||||
 | 
					        result = 31 * result + (dataType != null ? dataType.hashCode() : 0);
 | 
				
			||||||
 | 
					        return result;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 | 
				
			|||||||
import org.nd4j.linalg.api.ops.impl.controlflow.If;
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.If;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.controlflow.While;
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.While;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
 | 
					import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
 | 
				
			||||||
@ -246,6 +247,14 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
    private boolean resolvedVariables = false;
 | 
					    private boolean resolvedVariables = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Getter
 | 
				
			||||||
 | 
					    private Stack<ArgumentInterceptor> argumentInterceptors = new Stack<>();
 | 
				
			||||||
 | 
					    @Getter
 | 
				
			||||||
 | 
					    private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private Set<String> blockNames = new HashSet<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Getter
 | 
					    @Getter
 | 
				
			||||||
    @Setter
 | 
					    @Setter
 | 
				
			||||||
    boolean logExecution = true;
 | 
					    boolean logExecution = true;
 | 
				
			||||||
@ -472,7 +481,10 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        if(scope == null){
 | 
					        if(scope == null){
 | 
				
			||||||
            return name;
 | 
					            return name;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        return scope + "/" + name;
 | 
					        if(!name.startsWith(scope + "/"))
 | 
				
			||||||
 | 
					            return scope + "/" + name;
 | 
				
			||||||
 | 
					        else
 | 
				
			||||||
 | 
					            return name;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    //Intentionally package private
 | 
					    //Intentionally package private
 | 
				
			||||||
@ -533,6 +545,24 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public List<SameDiffOp> getOpsInScope(NameScope scope){
 | 
				
			||||||
 | 
					        ArrayList<SameDiffOp> ops = new ArrayList<>();
 | 
				
			||||||
 | 
					        for(SameDiffOp v : this.ops.values()){
 | 
				
			||||||
 | 
					            if(v.getName().startsWith(scope.getName()))
 | 
				
			||||||
 | 
					                ops.add(v);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return ops;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public List<SDVariable> getVariablesInScope(NameScope scope){
 | 
				
			||||||
 | 
					        ArrayList<SDVariable> vars = new ArrayList<>();
 | 
				
			||||||
 | 
					        for(SDVariable v : variables()){
 | 
				
			||||||
 | 
					            if(v.getVarName().startsWith(scope.getName()))
 | 
				
			||||||
 | 
					                vars.add(v);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        return vars;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * @param sameDiff
 | 
					     * @param sameDiff
 | 
				
			||||||
     * @return
 | 
					     * @return
 | 
				
			||||||
@ -1109,6 +1139,19 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Remove a property to resolve added with {@link #addPropertyToResolve(DifferentialFunction, String)}
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param forFunction the function to add the property to resolve for
 | 
				
			||||||
 | 
					     * @param arrayName   the array name
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName) {
 | 
				
			||||||
 | 
					        if (propertiesToResolve.containsKey(forFunction.getOwnName())) {
 | 
				
			||||||
 | 
					            List<String> newVal = propertiesToResolve.get(forFunction.getOwnName());
 | 
				
			||||||
 | 
					            newVal.remove(arrayName);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Return the properties to resolve for the given function.
 | 
					     * Return the properties to resolve for the given function.
 | 
				
			||||||
     * This is typically used right before execution in model import in
 | 
					     * This is typically used right before execution in model import in
 | 
				
			||||||
@ -1272,6 +1315,92 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Add a new argument interceptor to the interceptor stack
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * When a op is added with arguments, most recent argument interceptor is called on it.
 | 
				
			||||||
 | 
					     * If ops are added in that interceptor, the next most recent will be called on their args, and so on.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param interceptor  the argument interceptor to add
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){
 | 
				
			||||||
 | 
					        argumentInterceptors.push(interceptor);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor){
 | 
				
			||||||
 | 
					        return pausedArgumentInterceptors.contains(interceptor);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private ArgumentInterceptor getArgumentInterceptorToUse(){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(argumentInterceptors.isEmpty())
 | 
				
			||||||
 | 
					            return null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ArgumentInterceptor use = argumentInterceptors.peek();
 | 
				
			||||||
 | 
					        int i = 1;
 | 
				
			||||||
 | 
					        while(isArgumentInterceptorPaused(use)){
 | 
				
			||||||
 | 
					            if(argumentInterceptors.size() - i < 0)
 | 
				
			||||||
 | 
					                return null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            use = argumentInterceptors.elementAt(argumentInterceptors.size() - i);
 | 
				
			||||||
 | 
					            i++;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return use;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Remote the top (most recently added) argument interceptor
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void removeArgumentInterceptor(){
 | 
				
			||||||
 | 
					        if(!argumentInterceptors.isEmpty())
 | 
				
			||||||
 | 
					            argumentInterceptors.pop();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Pause the top (most recently added) argument interceptor
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void pauseArgumentInterceptor(){
 | 
				
			||||||
 | 
					        pausedArgumentInterceptors.add(argumentInterceptors.peek());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Pause the given argument interceptor
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param interceptor  the argument interceptor to pause
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){
 | 
				
			||||||
 | 
					        pausedArgumentInterceptors.add(interceptor);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Unpause the top (most recently added) argument interceptor
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void unpauseArgumentInterceptor(){
 | 
				
			||||||
 | 
					        pausedArgumentInterceptors.remove(argumentInterceptors.peek());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Unpause the top given argument interceptor
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param interceptor  the argument interceptor to unpause
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){
 | 
				
			||||||
 | 
					        pausedArgumentInterceptors.remove(interceptor);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Adds incoming arguments for the specified differential function to the graph
 | 
					     * Adds incoming arguments for the specified differential function to the graph
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
@ -1279,6 +1408,17 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
     * @param function  Function
 | 
					     * @param function  Function
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public void addArgsFor(String[] variables, DifferentialFunction function) {
 | 
					    public void addArgsFor(String[] variables, DifferentialFunction function) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ArgumentInterceptor interceptor = getArgumentInterceptorToUse();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(interceptor != null) {
 | 
				
			||||||
 | 
					            pauseArgumentInterceptor(interceptor);
 | 
				
			||||||
 | 
					            for (int i = 0; i < variables.length; i++) {
 | 
				
			||||||
 | 
					                variables[i] = interceptor.intercept(getVariable(variables[i])).getVarName();
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            unpauseArgumentInterceptor(interceptor);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (function.getOwnName() == null)
 | 
					        if (function.getOwnName() == null)
 | 
				
			||||||
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
 | 
					            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1309,7 +1449,6 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Adds incoming arguments for the specified differential function to the graph
 | 
					     * Adds incoming arguments for the specified differential function to the graph
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
@ -1317,6 +1456,7 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
     * @param function  Function
 | 
					     * @param function  Function
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
 | 
					    public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        String[] varNames = new String[variables.length];
 | 
					        String[] varNames = new String[variables.length];
 | 
				
			||||||
        for (int i = 0; i < varNames.length; i++) {
 | 
					        for (int i = 0; i < varNames.length; i++) {
 | 
				
			||||||
            if (variables[i] == null)
 | 
					            if (variables[i] == null)
 | 
				
			||||||
@ -1326,6 +1466,58 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        addArgsFor(varNames, function);
 | 
					        addArgsFor(varNames, function);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Replaces the argument at i with newArg for function
 | 
				
			||||||
 | 
					     * Does not use (or remove) ArgumentInterceptor stuff
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Preconditions.checkArgument(i < function.args().length, "Index out of range: function " +
 | 
				
			||||||
 | 
					                function.getOwnName() + " only has " + function.args().length + " args but you are trying" +
 | 
				
			||||||
 | 
					                "to replace the argument at " + i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        String oldName = function.arg(i).getVarName();
 | 
				
			||||||
 | 
					        String newName = newArg.getVarName();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()){
 | 
				
			||||||
 | 
					            boolean otherPlaceholders = false;
 | 
				
			||||||
 | 
					            for(int j = 0 ; j < function.argNames().length ; j++){
 | 
				
			||||||
 | 
					                if(j == i)
 | 
				
			||||||
 | 
					                    continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if(function.arg(j).isPlaceHolder())
 | 
				
			||||||
 | 
					                    otherPlaceholders = true;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if(!otherPlaceholders)
 | 
				
			||||||
 | 
					                placeHolderFunctions.remove(function.getOwnName());
 | 
				
			||||||
 | 
					        } else if(!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()){
 | 
				
			||||||
 | 
					            if(!placeHolderFunctions.contains(function.getOwnName()))
 | 
				
			||||||
 | 
					                placeHolderFunctions.add(function.getOwnName());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        List<String> oldArgs = ops.get(function.getOwnName()).getInputsToOp();
 | 
				
			||||||
 | 
					        oldArgs = new ArrayList<>(oldArgs);
 | 
				
			||||||
 | 
					        oldArgs.set(i, newName);
 | 
				
			||||||
 | 
					        ops.get(function.getOwnName()).setInputsToOp(oldArgs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        List<String> funcs = this.variables.get(newName).getInputsForOp();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (funcs == null) {
 | 
				
			||||||
 | 
					            funcs = new ArrayList<>();
 | 
				
			||||||
 | 
					            this.variables.get(newName).setInputsForOp(funcs);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if(!funcs.contains(function.getOwnName()))  //Avoid duplicates for function names.
 | 
				
			||||||
 | 
					            funcs.add(function.getOwnName());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        List<String> oldFuncs = this.variables.get(oldName).getInputsForOp();
 | 
				
			||||||
 | 
					        if(oldFuncs != null) {
 | 
				
			||||||
 | 
					            if(!ArrayUtils.contains(function.argNames(), oldName))
 | 
				
			||||||
 | 
					                oldFuncs.remove(function.getOwnName());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Get the differential function (if any) that this variable is the output for
 | 
					     * Get the differential function (if any) that this variable is the output for
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
@ -1519,6 +1711,7 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
                //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed
 | 
					                //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed
 | 
				
			||||||
                // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here
 | 
					                // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here
 | 
				
			||||||
 | 
					                // This applies to SameDiff while loops as well
 | 
				
			||||||
                if(o.getOp() instanceof Switch){
 | 
					                if(o.getOp() instanceof Switch){
 | 
				
			||||||
                    continue;
 | 
					                    continue;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
@ -2239,6 +2432,7 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        if (name == null || name.length() < 1)
 | 
					        if (name == null || name.length() < 1)
 | 
				
			||||||
            name = getNewVarName();
 | 
					            name = getNewVarName();
 | 
				
			||||||
        SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
 | 
					        SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
 | 
				
			||||||
 | 
					        name = v.getVarName();
 | 
				
			||||||
        variables.put(name, Variable.builder().name(name).variable(v).build());
 | 
					        variables.put(name, Variable.builder().name(name).variable(v).build());
 | 
				
			||||||
        constantArrays.put(name, new DeviceLocalNDArray(constant));
 | 
					        constantArrays.put(name, new DeviceLocalNDArray(constant));
 | 
				
			||||||
        return v;
 | 
					        return v;
 | 
				
			||||||
@ -2305,6 +2499,7 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
    public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
 | 
					    public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
 | 
				
			||||||
                             org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
 | 
					                             org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
 | 
				
			||||||
        String withScope = nameWithScope(name);
 | 
					        String withScope = nameWithScope(name);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if (variables.containsKey(withScope)) {
 | 
					        if (variables.containsKey(withScope)) {
 | 
				
			||||||
            if(nameScopes.isEmpty()){
 | 
					            if(nameScopes.isEmpty()){
 | 
				
			||||||
                throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
 | 
					                throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
 | 
				
			||||||
@ -3414,12 +3609,9 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Creates a while statement
 | 
					     * @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @param sameDiffConditional
 | 
					 | 
				
			||||||
     * @param loopBody
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
 | 
					    @Deprecated
 | 
				
			||||||
    public While whileStatement(SameDiffConditional sameDiffConditional,
 | 
					    public While whileStatement(SameDiffConditional sameDiffConditional,
 | 
				
			||||||
                                SameDiffFunctionDefinition conditionBody,
 | 
					                                SameDiffFunctionDefinition conditionBody,
 | 
				
			||||||
                                SameDiffFunctionDefinition loopBody
 | 
					                                SameDiffFunctionDefinition loopBody
 | 
				
			||||||
@ -3435,11 +3627,9 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * @param conditional
 | 
					     * @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
 | 
				
			||||||
     * @param trueBody
 | 
					 | 
				
			||||||
     * @param falseBody
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
 | 
					    @Deprecated
 | 
				
			||||||
    public If ifStatement(SameDiffConditional conditional,
 | 
					    public If ifStatement(SameDiffConditional conditional,
 | 
				
			||||||
                          SameDiffFunctionDefinition conditionBody,
 | 
					                          SameDiffFunctionDefinition conditionBody,
 | 
				
			||||||
                          SameDiffFunctionDefinition trueBody,
 | 
					                          SameDiffFunctionDefinition trueBody,
 | 
				
			||||||
@ -5466,5 +5656,27 @@ public class SameDiff extends SDBaseOps {
 | 
				
			|||||||
        return out;
 | 
					        return out;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * For internal use only.
 | 
				
			||||||
 | 
					     * Creates a new discinct block name from baseName.
 | 
				
			||||||
 | 
					     * Block names are used by If and While
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public String newBlockName(String baseName){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(baseName == null)
 | 
				
			||||||
 | 
					            return null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(!blockNames.contains(baseName)){
 | 
				
			||||||
 | 
					            blockNames.add(baseName);
 | 
				
			||||||
 | 
					            return baseName;
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            int i = 1;
 | 
				
			||||||
 | 
					            while(blockNames.contains(baseName + "_" + i)){
 | 
				
			||||||
 | 
					                i++;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            blockNames.add(baseName + "_" + i);
 | 
				
			||||||
 | 
					            return baseName + "_" + i;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					/*******************************************************************************
 | 
				
			||||||
 | 
					 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 * under the License.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package org.nd4j.autodiff.samediff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * A basic SameDiff lambda, used in while loop creation (the body).
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					public interface SameDiffLambda {
 | 
				
			||||||
 | 
					    SDVariable[] define(SameDiff sameDiff, SDVariable[] inputs);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					/*******************************************************************************
 | 
				
			||||||
 | 
					 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 * under the License.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package org.nd4j.autodiff.samediff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * A SameDiff lambda with only one output and no arguments.  Used in if condition creation (the condition and bodies).
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					public interface SameDiffNoArgSingleLambda {
 | 
				
			||||||
 | 
					    SDVariable define(SameDiff sameDiff);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					/*******************************************************************************
 | 
				
			||||||
 | 
					 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * This program and the accompanying materials are made available under the
 | 
				
			||||||
 | 
					 * terms of the Apache License, Version 2.0 which is available at
 | 
				
			||||||
 | 
					 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
				
			||||||
 | 
					 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
				
			||||||
 | 
					 * License for the specific language governing permissions and limitations
 | 
				
			||||||
 | 
					 * under the License.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
 | 
					 ******************************************************************************/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package org.nd4j.autodiff.samediff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * A SameDiff lambda with only one output, used in while loop creation (the condition).
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					public interface SameDiffSingleLambda {
 | 
				
			||||||
 | 
					    SDVariable define(SameDiff sameDiff, SDVariable[] inputs);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -16,12 +16,25 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.autodiff.samediff.ops;
 | 
					package org.nd4j.autodiff.samediff.ops;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.google.common.collect.Sets;
 | 
				
			||||||
 | 
					import java.util.HashMap;
 | 
				
			||||||
 | 
					import java.util.HashSet;
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					import java.util.Set;
 | 
				
			||||||
import lombok.NonNull;
 | 
					import lombok.NonNull;
 | 
				
			||||||
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
 | 
					import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.ArgumentInterceptor;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.NameScope;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
					import org.nd4j.autodiff.samediff.SDVariable;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
					import org.nd4j.autodiff.samediff.SameDiff;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.SameDiffLambda;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.SameDiffSingleLambda;
 | 
				
			||||||
 | 
					import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 | 
				
			||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
 | 
					import org.nd4j.linalg.api.blas.params.MMulTranspose;
 | 
				
			||||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
					import org.nd4j.linalg.api.buffer.DataType;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
 | 
					import org.nd4j.linalg.api.ops.impl.shape.OneHot;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
 | 
				
			||||||
import org.nd4j.linalg.indexing.conditions.Condition;
 | 
					import org.nd4j.linalg.indexing.conditions.Condition;
 | 
				
			||||||
@ -3142,4 +3155,304 @@ public abstract class SDBaseOps {
 | 
				
			|||||||
        SDVariable ret = f().zerosLike(name, input);
 | 
					        SDVariable ret = f().zerosLike(name, input);
 | 
				
			||||||
        return updateVariableNameAndReference(ret, name);
 | 
					        return updateVariableNameAndReference(ret, name);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * See {@link #any(String, SDVariable, int...)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable any(SDVariable x, int... dimensions){
 | 
				
			||||||
 | 
					        return any(null, x, dimensions);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    //TODO check any w/ no dimensions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Boolean or array reduction operation, optionally along specified dimensions
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param name   Name of the output variable
 | 
				
			||||||
 | 
					     * @param x   Input variable
 | 
				
			||||||
 | 
					     * @param dimensions    Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 | 
				
			||||||
 | 
					     * @return Output variable: reduced array of rank (input rank - num dimensions)
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable any(String name, SDVariable x, int... dimensions){
 | 
				
			||||||
 | 
					        validateBool("any", x);
 | 
				
			||||||
 | 
					        SDVariable ret = f().any(x, dimensions);
 | 
				
			||||||
 | 
					        return updateVariableNameAndReference(ret, name);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * See {@link #all(String, SDVariable, int...)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable all(SDVariable x, int... dimensions){
 | 
				
			||||||
 | 
					        return all(null, x, dimensions);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Boolean and array reduction operation, optionally along specified dimensions
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param name   Name of the output variable
 | 
				
			||||||
 | 
					     * @param x   Input variable
 | 
				
			||||||
 | 
					     * @param dimensions    Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
 | 
				
			||||||
 | 
					     * @return Output variable: reduced array of rank (input rank - num dimensions)
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable all(String name, SDVariable x, int... dimensions){
 | 
				
			||||||
 | 
					        validateBool("all", x);
 | 
				
			||||||
 | 
					        SDVariable ret = f().all(x, dimensions);
 | 
				
			||||||
 | 
					        return updateVariableNameAndReference(ret, name);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars,
 | 
				
			||||||
 | 
					            @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
 | 
				
			||||||
 | 
					        return whileLoop(null, null, loopVars, cond, body);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars,
 | 
				
			||||||
 | 
					            @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
 | 
				
			||||||
 | 
					        return whileLoop(null, loopName, loopVars, cond, body);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration)
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * Note that cond and body lambdas are only called once to construct the graph.  The constructed graph is used for further iterations.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param outputNames  Names to give the output variables.  If null, doesn't rename
 | 
				
			||||||
 | 
					     * @param loopName  The name of the loop block and frame (must be unique).  If null, uses "if"
 | 
				
			||||||
 | 
					     * @param loopVars  Loop variables' inputs
 | 
				
			||||||
 | 
					     * @param cond  A lambda evaluating to the loop condition
 | 
				
			||||||
 | 
					     * @param body  A lambda doing the loop operation and returning the new loop variable values
 | 
				
			||||||
 | 
					     * @return  The values of the loop variables once condition is false
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars,
 | 
				
			||||||
 | 
					            @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final String frameName = sd().newBlockName(loopName == null ? "while" : loopName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope loopScope = sd().withNameScope(frameName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable[] entered = new SDVariable[loopVars.length];
 | 
				
			||||||
 | 
					        for(int i = 0 ; i < loopVars.length ; i++){
 | 
				
			||||||
 | 
					            entered[i] = f().enter(loopVars[i], frameName);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //counter = SD.f().enter(counter, frameName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable[] merged = new SDVariable[loopVars.length];
 | 
				
			||||||
 | 
					        Merge[] mergeOps = new Merge[loopVars.length];
 | 
				
			||||||
 | 
					        for(int i = 0 ; i < loopVars.length ; i++){
 | 
				
			||||||
 | 
					            // the second arg will later be replaced with the output of NextIteration
 | 
				
			||||||
 | 
					            // but that isn't available yet (and can't be, as it depends on this)
 | 
				
			||||||
 | 
					            mergeOps[i] = new Merge(sd(), entered[i], entered[i]);
 | 
				
			||||||
 | 
					            merged[i] = mergeOps[i].outputVariable();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //Merge counterMerge = new Merge(SD, counter, counter);
 | 
				
			||||||
 | 
					        //counter = counterMerge.outputVariable();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope condScope = sd().withNameScope("cond");
 | 
				
			||||||
 | 
					        SDVariable cond_result = cond.define(sd(), merged);
 | 
				
			||||||
 | 
					        condScope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (cond_result.dataType() != DataType.BOOL)
 | 
				
			||||||
 | 
					            throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final Set<String> alreadyEntered = Sets.newHashSet();
 | 
				
			||||||
 | 
					        SDVariable[] trueSwitches = new SDVariable[loopVars.length];
 | 
				
			||||||
 | 
					        SDVariable[] exits = new SDVariable[loopVars.length];
 | 
				
			||||||
 | 
					        for(int i = 0 ; i < loopVars.length ; i++){
 | 
				
			||||||
 | 
					            SDVariable[] s = f().switchOp(merged[i], cond_result);
 | 
				
			||||||
 | 
					            trueSwitches[i] = s[1];
 | 
				
			||||||
 | 
					            alreadyEntered.add(s[1].getVarName());
 | 
				
			||||||
 | 
					            exits[i] = f().exit(s[0]);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //SDVariable[] cs = SD.f().switchOp(counter, cond_result);
 | 
				
			||||||
 | 
					        //SDVariable counterExit = SD.f().exit(cs[0]);
 | 
				
			||||||
 | 
					        //counter = cs[1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final Set<String> declared = Sets.newHashSet(sd().variableMap().keySet());
 | 
				
			||||||
 | 
					        final Map<String, SDVariable> done = new HashMap<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sd().addArgumentInterceptor(new ArgumentInterceptor() {
 | 
				
			||||||
 | 
					            @Override
 | 
				
			||||||
 | 
					            public SDVariable intercept(SDVariable argument) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if(!declared.contains(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return argument;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if(alreadyEntered.contains(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return argument;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                if(done.containsKey(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return done.get(argument.getVarName());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                SDVariable e = f().enter(argument, frameName, true);
 | 
				
			||||||
 | 
					                done.put(argument.getVarName(), e);
 | 
				
			||||||
 | 
					                return e;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope bodyScope = sd().withNameScope("body");
 | 
				
			||||||
 | 
					        SDVariable[] outs = body.define(sd(), trueSwitches);
 | 
				
			||||||
 | 
					        bodyScope.close();
 | 
				
			||||||
 | 
					        sd().removeArgumentInterceptor();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //counter.add(1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for(int i = 0 ; i < loopVars.length ; i++){
 | 
				
			||||||
 | 
					            SDVariable n = f().nextIteration(outs[i]);
 | 
				
			||||||
 | 
					            mergeOps[i].replaceArg(1,n);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //counterMerge.replaceArg(1, counter);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        loopScope.close();
 | 
				
			||||||
 | 
					        return updateVariableNamesAndReferences(exits, outputNames);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond,
 | 
				
			||||||
 | 
					            @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
 | 
				
			||||||
 | 
					        return ifCond(null, null, cond, trueBody, falseBody);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond,
 | 
				
			||||||
 | 
					            @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
 | 
				
			||||||
 | 
					        return ifCond(null, ifName, cond, trueBody, falseBody);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Constructs a If statement using the tensorflow style control flow operations (Switch and Merge)
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * Note that cond and body lambdas are only called once to construct the graph.  The constructed graph is used to evaluate.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param outputName Name to give the output variable.  If null, doesn't rename
 | 
				
			||||||
 | 
					     * @param ifName  The name of the if block.  If null, uses "if"
 | 
				
			||||||
 | 
					     * @param cond  A lambda evaluating to the if condition
 | 
				
			||||||
 | 
					     * @param trueBody  A lambda to be executed if cond is true (the if block)
 | 
				
			||||||
 | 
					     * @param falseBody  A lambda to be executed if cond is false (the else block)
 | 
				
			||||||
 | 
					     * @return The value of trueBody if cond is true, or falseBody if it isn't
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond,
 | 
				
			||||||
 | 
					            @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ifName = sd().newBlockName(ifName == null ? "if" : ifName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope ifScope = sd().withNameScope(ifName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope condScope = sd().withNameScope("cond");
 | 
				
			||||||
 | 
					        final SDVariable pred = cond.define(sd());
 | 
				
			||||||
 | 
					        condScope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (pred.dataType() != DataType.BOOL) {
 | 
				
			||||||
 | 
					            //cleanup partially added block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for(SDVariable v : sd().getVariablesInScope(ifScope))
 | 
				
			||||||
 | 
					                sd().getVariables().remove(v.getVarName());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for(SameDiffOp op : sd().getOpsInScope(ifScope)) {
 | 
				
			||||||
 | 
					                for(String in : op.getInputsToOp()){
 | 
				
			||||||
 | 
					                    sd().removeArgFromFunction(in, op.getOp());
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                sd().getOps().remove(op.getName());
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            throw new IllegalStateException("Can not use " + pred.getVarName()
 | 
				
			||||||
 | 
					                    + " as the condition of an If statement, the condition must be a boolean.");
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final Map<String, SDVariable[]> switches = new HashMap<>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final Set<String> declared = Sets.newHashSet(sd().variableMap().keySet());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sd().addArgumentInterceptor(new ArgumentInterceptor() {
 | 
				
			||||||
 | 
					            @Override
 | 
				
			||||||
 | 
					            public SDVariable intercept(SDVariable argument) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // if its declared in the if, we don't care acout it
 | 
				
			||||||
 | 
					                if(!declared.contains(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return argument;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // if we've already added a switch, move on
 | 
				
			||||||
 | 
					                if(switches.containsKey(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return switches.get(argument.getVarName())[1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                SDVariable[] s = f().switchOp(argument, pred);
 | 
				
			||||||
 | 
					                switches.put(argument.getVarName(), s);
 | 
				
			||||||
 | 
					                return s[1];
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					        NameScope trueScope = sd().withNameScope("trueBody");
 | 
				
			||||||
 | 
					        SDVariable trueOut = trueBody.define(sd());
 | 
				
			||||||
 | 
					        sd().removeArgumentInterceptor();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(declared.contains(trueOut.getVarName())) {
 | 
				
			||||||
 | 
					            SDVariable[] s = f().switchOp(trueOut, pred);
 | 
				
			||||||
 | 
					            switches.put(trueOut.getVarName(), s);
 | 
				
			||||||
 | 
					            trueOut = s[1];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        trueScope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        final Set<String> declared2 = Sets.newHashSet(sd().variableMap().keySet());
 | 
				
			||||||
 | 
					        sd().addArgumentInterceptor(new ArgumentInterceptor() {
 | 
				
			||||||
 | 
					            @Override
 | 
				
			||||||
 | 
					            public SDVariable intercept(SDVariable argument) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // if its declared in the if, we don't care acout it
 | 
				
			||||||
 | 
					                if(!declared2.contains(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return argument;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // if we've already added a switch, move on
 | 
				
			||||||
 | 
					                if(switches.containsKey(argument.getVarName()))
 | 
				
			||||||
 | 
					                    return switches.get(argument.getVarName())[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                SDVariable[] s = f().switchOp(argument, pred);
 | 
				
			||||||
 | 
					                switches.put(argument.getVarName(), s);
 | 
				
			||||||
 | 
					                return s[0];
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					        NameScope falseScope = sd().withNameScope("falseBody");
 | 
				
			||||||
 | 
					        SDVariable falseOut = falseBody.define(sd());
 | 
				
			||||||
 | 
					        sd().removeArgumentInterceptor();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if(declared2.contains(falseOut.getVarName())) {
 | 
				
			||||||
 | 
					            SDVariable[] s = f().switchOp(falseOut, pred);
 | 
				
			||||||
 | 
					            switches.put(falseOut.getVarName(), s);
 | 
				
			||||||
 | 
					            falseOut = s[0];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        falseScope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable output = f().merge(trueOut, falseOut);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ifScope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return updateVariableNameAndReference(output, outputName);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -411,6 +411,29 @@ public class SDNN extends SDOps {
 | 
				
			|||||||
        return updateVariableNameAndReference(ret, name);
 | 
					        return updateVariableNameAndReference(ret, name);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Log softmax activation
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param x Input variable
 | 
				
			||||||
 | 
					     * @return Output variable
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable logSoftmax(SDVariable x, int dimension) {
 | 
				
			||||||
 | 
					        return logSoftmax(null, x, dimension);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Log softmax activation
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param name Variable name
 | 
				
			||||||
 | 
					     * @param x    Input variable
 | 
				
			||||||
 | 
					     * @return Output variable
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable logSoftmax(String name, SDVariable x, int dimension) {
 | 
				
			||||||
 | 
					        validateFloatingPoint("log softmax", x);
 | 
				
			||||||
 | 
					        SDVariable ret = f().logSoftmax(x, dimension);
 | 
				
			||||||
 | 
					        return updateVariableNameAndReference(ret, name);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * Element-wise rectified linear function with specified cutoff:<br>
 | 
					     * Element-wise rectified linear function with specified cutoff:<br>
 | 
				
			||||||
     * out[i] = in[i] if in[i] >= cutoff
 | 
					     * out[i] = in[i] if in[i] >= cutoff
 | 
				
			||||||
@ -591,6 +614,28 @@ public class SDNN extends SDOps {
 | 
				
			|||||||
        return updateVariableNameAndReference(result, name);
 | 
					        return updateVariableNameAndReference(result, name);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Softmax activation
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param x Input variable
 | 
				
			||||||
 | 
					     * @return Output variable
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable softmax(SDVariable x, int dimension) {
 | 
				
			||||||
 | 
					        return softmax(null, x, dimension);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * Softmax activation
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param x Input variable
 | 
				
			||||||
 | 
					     * @return Output variable
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public SDVariable softmax(String name, SDVariable x, int dimension) {
 | 
				
			||||||
 | 
					        validateFloatingPoint("softmax", x);
 | 
				
			||||||
 | 
					        SDVariable result = f().softmax(x, dimension);
 | 
				
			||||||
 | 
					        return updateVariableNameAndReference(result, name);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * @param x
 | 
					     * @param x
 | 
				
			||||||
     * @return
 | 
					     * @return
 | 
				
			||||||
 | 
				
			|||||||
@ -17,36 +17,47 @@
 | 
				
			|||||||
package org.nd4j.autodiff.samediff.serde;
 | 
					package org.nd4j.autodiff.samediff.serde;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import com.google.flatbuffers.FlatBufferBuilder;
 | 
					import com.google.flatbuffers.FlatBufferBuilder;
 | 
				
			||||||
 | 
					import java.nio.ByteOrder;
 | 
				
			||||||
 | 
					import java.util.Arrays;
 | 
				
			||||||
 | 
					import java.util.HashMap;
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
import lombok.NonNull;
 | 
					import lombok.NonNull;
 | 
				
			||||||
import lombok.val;
 | 
					import lombok.val;
 | 
				
			||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
 | 
					import org.nd4j.autodiff.functions.DifferentialFunction;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
					 | 
				
			||||||
import org.nd4j.autodiff.samediff.VariableType;
 | 
					import org.nd4j.autodiff.samediff.VariableType;
 | 
				
			||||||
import org.nd4j.base.Preconditions;
 | 
					import org.nd4j.base.Preconditions;
 | 
				
			||||||
import org.nd4j.graph.*;
 | 
					import org.nd4j.graph.DataType;
 | 
				
			||||||
 | 
					import org.nd4j.graph.FlatArray;
 | 
				
			||||||
 | 
					import org.nd4j.graph.FlatNode;
 | 
				
			||||||
 | 
					import org.nd4j.graph.FlatProperties;
 | 
				
			||||||
 | 
					import org.nd4j.graph.IntPair;
 | 
				
			||||||
 | 
					import org.nd4j.graph.OpType;
 | 
				
			||||||
 | 
					import org.nd4j.graph.VarType;
 | 
				
			||||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 | 
					import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 | 
				
			||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
 | 
					 | 
				
			||||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
					import org.nd4j.linalg.api.ndarray.INDArray;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.*;
 | 
					import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.BaseReduceOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.CustomOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op.Type;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.ScalarOp;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.Shape;
 | 
					import org.nd4j.linalg.api.shape.Shape;
 | 
				
			||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
 | 
					import org.nd4j.linalg.exception.ND4JIllegalStateException;
 | 
				
			||||||
import org.nd4j.linalg.factory.Nd4j;
 | 
					import org.nd4j.linalg.factory.Nd4j;
 | 
				
			||||||
import org.nd4j.linalg.primitives.Pair;
 | 
					 | 
				
			||||||
import org.nd4j.linalg.util.ArrayUtil;
 | 
					import org.nd4j.linalg.util.ArrayUtil;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.nio.ByteOrder;
 | 
					 | 
				
			||||||
import java.util.*;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
public class FlatBuffersMapper {
 | 
					public class FlatBuffersMapper {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private FlatBuffersMapper(){ }
 | 
					    private FlatBuffersMapper() {
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * This method converts enums for DataType
 | 
					     * This method converts enums for DataType
 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @param type
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
 | 
					    public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
 | 
				
			||||||
        switch (type) {
 | 
					        switch (type) {
 | 
				
			||||||
@ -84,88 +95,87 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * This method converts enums for DataType
 | 
					     * This method converts enums for DataType
 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @param val
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
 | 
					    public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
 | 
				
			||||||
        if (val == DataType.FLOAT)
 | 
					        if (val == DataType.FLOAT) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.FLOAT;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.FLOAT;
 | 
				
			||||||
        else if (val == DataType.DOUBLE)
 | 
					        } else if (val == DataType.DOUBLE) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
 | 
				
			||||||
        else if (val == DataType.HALF)
 | 
					        } else if (val == DataType.HALF) {
 | 
				
			||||||
            return  org.nd4j.linalg.api.buffer.DataType.HALF;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.HALF;
 | 
				
			||||||
        else if (val == DataType.INT32)
 | 
					        } else if (val == DataType.INT32) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.INT;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.INT;
 | 
				
			||||||
        else if (val == DataType.INT64)
 | 
					        } else if (val == DataType.INT64) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.LONG;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.LONG;
 | 
				
			||||||
        else if (val == DataType.INT8)
 | 
					        } else if (val == DataType.INT8) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.BYTE;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.BYTE;
 | 
				
			||||||
        else if (val == DataType.BOOL)
 | 
					        } else if (val == DataType.BOOL) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.BOOL;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.BOOL;
 | 
				
			||||||
        else if (val == DataType.UINT8)
 | 
					        } else if (val == DataType.UINT8) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.UBYTE;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.UBYTE;
 | 
				
			||||||
        else if (val == DataType.INT16)
 | 
					        } else if (val == DataType.INT16) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.SHORT;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.SHORT;
 | 
				
			||||||
        else if (val == DataType.UTF8)
 | 
					        } else if (val == DataType.UTF8) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.UTF8;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.UTF8;
 | 
				
			||||||
        else if (val == DataType.UINT16)
 | 
					        } else if (val == DataType.UINT16) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.UINT16;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.UINT16;
 | 
				
			||||||
        else if (val == DataType.UINT32)
 | 
					        } else if (val == DataType.UINT32) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.UINT32;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.UINT32;
 | 
				
			||||||
        else if (val == DataType.UINT64)
 | 
					        } else if (val == DataType.UINT64) {
 | 
				
			||||||
            return org.nd4j.linalg.api.buffer.DataType.UINT64;
 | 
					            return org.nd4j.linalg.api.buffer.DataType.UINT64;
 | 
				
			||||||
        else
 | 
					        } else {
 | 
				
			||||||
            throw new RuntimeException("Unknown datatype: " + val);
 | 
					            throw new RuntimeException("Unknown datatype: " + val);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * This method return operation ID for given op name/type pair.
 | 
					     * This method return operation ID for given op name/type pair.
 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @param name
 | 
					 | 
				
			||||||
     * @param type
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public static long getOpNum(String name, Op.Type type) {
 | 
					    public static long getOpNum(String name, Op.Type type) {
 | 
				
			||||||
        if (type == Op.Type.LOOP) {
 | 
					        if (type == Op.Type.LOOP) {
 | 
				
			||||||
            return 0;
 | 
					            return 0;
 | 
				
			||||||
        } else if (type == Op.Type.RETURN) {
 | 
					        } else if (type == Op.Type.RETURN) {
 | 
				
			||||||
            return 40;
 | 
					            return 40;
 | 
				
			||||||
        } else if (type == Op.Type.IF) {
 | 
					 | 
				
			||||||
            return 30;
 | 
					 | 
				
			||||||
        } else if (type == Op.Type.CONDITIONAL) {
 | 
					        } else if (type == Op.Type.CONDITIONAL) {
 | 
				
			||||||
            return 10;
 | 
					            return 10;
 | 
				
			||||||
        } else if (type == Op.Type.MERGE) {
 | 
					 | 
				
			||||||
            return 60L;
 | 
					 | 
				
			||||||
        } else if (type == Op.Type.LOOP_COND) {
 | 
					        } else if (type == Op.Type.LOOP_COND) {
 | 
				
			||||||
            return 70L;
 | 
					            return 70L;
 | 
				
			||||||
        } else if (type == Op.Type.NEXT_ITERATION) {
 | 
					        } else if (type == Type.LOGIC) {
 | 
				
			||||||
            return 80L;
 | 
					            switch (name) {
 | 
				
			||||||
        } else if (type == Op.Type.EXIT) {
 | 
					                case Enter.OP_NAME:
 | 
				
			||||||
            return 90L;
 | 
					                    return Enter.OP_NUM;
 | 
				
			||||||
        } else if (type == Op.Type.ENTER) {
 | 
					                case Exit.OP_NAME:
 | 
				
			||||||
            return 100L;
 | 
					                    return Exit.OP_NUM;
 | 
				
			||||||
 | 
					                case NextIteration.OP_NAME:
 | 
				
			||||||
 | 
					                    return NextIteration.OP_NUM;
 | 
				
			||||||
 | 
					                case Merge.OP_NAME:
 | 
				
			||||||
 | 
					                    return Merge.OP_NUM;
 | 
				
			||||||
 | 
					                case Switch.OP_NAME:
 | 
				
			||||||
 | 
					                    return Switch.OP_NUM;
 | 
				
			||||||
 | 
					                default:
 | 
				
			||||||
 | 
					                    throw new IllegalStateException("Unknown LOGIC op with name: " + name);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
        } else if (type == Op.Type.CUSTOM) {
 | 
					        } else if (type == Op.Type.CUSTOM) {
 | 
				
			||||||
            val name2 = Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase());
 | 
					            val name2 = Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase());
 | 
				
			||||||
            if (name2 == null) {
 | 
					            if (name2 == null) {
 | 
				
			||||||
                val name3 = Nd4j.getExecutioner().getCustomOperations().get(name);
 | 
					                val name3 = Nd4j.getExecutioner().getCustomOperations().get(name);
 | 
				
			||||||
                if (name3 == null)
 | 
					                if (name3 == null) {
 | 
				
			||||||
                    return 0;
 | 
					                    return 0;
 | 
				
			||||||
                else
 | 
					                } else {
 | 
				
			||||||
                    return name3.getHash();
 | 
					                    return name3.getHash();
 | 
				
			||||||
            } else
 | 
					                }
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
                return name2.getHash();
 | 
					                return name2.getHash();
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
            //return Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase()).getHash();
 | 
					            //return Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase()).getHash();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            try {
 | 
					            try {
 | 
				
			||||||
                DifferentialFunction op =  DifferentialFunctionClassHolder.getInstance().getInstance(name);
 | 
					                DifferentialFunction op = DifferentialFunctionClassHolder.getInstance().getInstance(name);
 | 
				
			||||||
                return  op.opNum();
 | 
					                return op.opNum();
 | 
				
			||||||
            } catch (Exception e) {
 | 
					            } catch (Exception e) {
 | 
				
			||||||
                throw new RuntimeException("Could not find op number for operation: [" + name + "]",e);
 | 
					                throw new RuntimeException("Could not find op number for operation: [" + name + "]", e);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -212,7 +222,7 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
            case OpType.RANDOM:
 | 
					            case OpType.RANDOM:
 | 
				
			||||||
                return Op.Type.RANDOM;
 | 
					                return Op.Type.RANDOM;
 | 
				
			||||||
            case OpType.LOGIC:
 | 
					            case OpType.LOGIC:
 | 
				
			||||||
                return Op.Type.META;
 | 
					                return Type.LOGIC;
 | 
				
			||||||
            case OpType.CUSTOM:
 | 
					            case OpType.CUSTOM:
 | 
				
			||||||
                return Op.Type.CUSTOM;
 | 
					                return Op.Type.CUSTOM;
 | 
				
			||||||
            case OpType.PAIRWISE:
 | 
					            case OpType.PAIRWISE:
 | 
				
			||||||
@ -269,15 +279,11 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
                return OpType.INDEX_REDUCE;
 | 
					                return OpType.INDEX_REDUCE;
 | 
				
			||||||
            case RANDOM:
 | 
					            case RANDOM:
 | 
				
			||||||
                return OpType.RANDOM;
 | 
					                return OpType.RANDOM;
 | 
				
			||||||
            case MERGE:
 | 
					 | 
				
			||||||
            case CONDITIONAL:
 | 
					            case CONDITIONAL:
 | 
				
			||||||
            case LOOP:
 | 
					            case LOOP:
 | 
				
			||||||
            case RETURN:
 | 
					            case RETURN:
 | 
				
			||||||
            case ENTER:
 | 
					 | 
				
			||||||
            case EXIT:
 | 
					 | 
				
			||||||
            case NEXT_ITERATION:
 | 
					 | 
				
			||||||
            case LOOP_COND:
 | 
					            case LOOP_COND:
 | 
				
			||||||
            case IF:
 | 
					            case LOGIC:
 | 
				
			||||||
                return OpType.LOGIC;
 | 
					                return OpType.LOGIC;
 | 
				
			||||||
            case CUSTOM:
 | 
					            case CUSTOM:
 | 
				
			||||||
                return OpType.CUSTOM;
 | 
					                return OpType.CUSTOM;
 | 
				
			||||||
@ -295,88 +301,87 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * This method just converts enums
 | 
					     * This method just converts enums
 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @param val
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public static ByteOrder getOrderFromByte(byte val) {
 | 
					    public static ByteOrder getOrderFromByte(byte val) {
 | 
				
			||||||
        if (val == org.nd4j.graph.ByteOrder.LE)
 | 
					        if (val == org.nd4j.graph.ByteOrder.LE) {
 | 
				
			||||||
            return ByteOrder.LITTLE_ENDIAN;
 | 
					            return ByteOrder.LITTLE_ENDIAN;
 | 
				
			||||||
        else
 | 
					        } else {
 | 
				
			||||||
            return ByteOrder.BIG_ENDIAN;
 | 
					            return ByteOrder.BIG_ENDIAN;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * This method returns current byte order for this JVM as libnd4j enum
 | 
					     * This method returns current byte order for this JVM as libnd4j enum
 | 
				
			||||||
     *
 | 
					 | 
				
			||||||
     * @return
 | 
					 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
    public static byte getOrderAsByte() {
 | 
					    public static byte getOrderAsByte() {
 | 
				
			||||||
        if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN))
 | 
					        if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) {
 | 
				
			||||||
            return org.nd4j.graph.ByteOrder.BE;
 | 
					            return org.nd4j.graph.ByteOrder.BE;
 | 
				
			||||||
        else
 | 
					        } else {
 | 
				
			||||||
            return org.nd4j.graph.ByteOrder.LE;
 | 
					            return org.nd4j.graph.ByteOrder.LE;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static DifferentialFunction fromFlatNode(FlatNode fn){
 | 
					    public static DifferentialFunction fromFlatNode(FlatNode fn) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        int id = fn.id();               //ID of the node
 | 
					        int id = fn.id();               //ID of the node
 | 
				
			||||||
        String name = fn.name();        //Name of the node, NOT the name of the op
 | 
					        String name = fn.name();        //Name of the node, NOT the name of the op
 | 
				
			||||||
        Op.Type opType = FlatBuffersMapper.getTypeFromByte(fn.opType());
 | 
					        Op.Type opType = FlatBuffersMapper.getTypeFromByte(fn.opType());
 | 
				
			||||||
        long opNum = fn.opNum();        //Op num: hash for custom, number for legacy
 | 
					        long opNum = fn.opNum();        //Op num: hash for custom, number for legacy
 | 
				
			||||||
        int[] input = new int[fn.inputLength()];
 | 
					        int[] input = new int[fn.inputLength()];
 | 
				
			||||||
        for( int i=0; i<input.length; i++ ){
 | 
					        for (int i = 0; i < input.length; i++) {
 | 
				
			||||||
            input[i] = fn.input(i);
 | 
					            input[i] = fn.input(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        IntPair[] inputPaired = new IntPair[fn.inputPairedLength()];
 | 
					        IntPair[] inputPaired = new IntPair[fn.inputPairedLength()];
 | 
				
			||||||
        for( int i=0; i<inputPaired.length; i++ ){
 | 
					        for (int i = 0; i < inputPaired.length; i++) {
 | 
				
			||||||
            inputPaired[i] = fn.inputPaired(i);
 | 
					            inputPaired[i] = fn.inputPaired(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        int[] output = new int[fn.outputLength()];
 | 
					        int[] output = new int[fn.outputLength()];
 | 
				
			||||||
        for( int i=0; i<output.length; i++ ){
 | 
					        for (int i = 0; i < output.length; i++) {
 | 
				
			||||||
            output[i] = fn.output(i);
 | 
					            output[i] = fn.output(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        double[] extraParams = new double[fn.extraParamsLength()];
 | 
					        double[] extraParams = new double[fn.extraParamsLength()];
 | 
				
			||||||
        for( int i=0; i<extraParams.length; i++ ){
 | 
					        for (int i = 0; i < extraParams.length; i++) {
 | 
				
			||||||
            extraParams[i] = fn.extraParams(i);
 | 
					            extraParams[i] = fn.extraParams(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        long[] extraInteger = new long[fn.extraIntegerLength()];
 | 
					        long[] extraInteger = new long[fn.extraIntegerLength()];
 | 
				
			||||||
        for( int i=0; i<extraInteger.length; i++ ){
 | 
					        for (int i = 0; i < extraInteger.length; i++) {
 | 
				
			||||||
            extraInteger[i] = fn.extraInteger(i);
 | 
					            extraInteger[i] = fn.extraInteger(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        boolean[] extraBools = new boolean[fn.extraBoolsLength()];
 | 
					        boolean[] extraBools = new boolean[fn.extraBoolsLength()];
 | 
				
			||||||
        for( int i=0; i<extraBools.length; i++ ){
 | 
					        for (int i = 0; i < extraBools.length; i++) {
 | 
				
			||||||
            extraBools[i] = fn.extraBools(i);
 | 
					            extraBools[i] = fn.extraBools(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        int[] dimensions = new int[fn.dimensionsLength()];
 | 
					        int[] dimensions = new int[fn.dimensionsLength()];
 | 
				
			||||||
        for( int i=0; i<dimensions.length; i++ ){
 | 
					        for (int i = 0; i < dimensions.length; i++) {
 | 
				
			||||||
            dimensions[i] = fn.dimensions(i);
 | 
					            dimensions[i] = fn.dimensions(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        FlatArray fa = fn.scalar();
 | 
					        FlatArray fa = fn.scalar();
 | 
				
			||||||
        INDArray scalar = null;
 | 
					        INDArray scalar = null;
 | 
				
			||||||
        if(fa != null){
 | 
					        if (fa != null) {
 | 
				
			||||||
            scalar = Nd4j.createFromFlatArray(fa);
 | 
					            scalar = Nd4j.createFromFlatArray(fa);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        FlatProperties[] flatProperties = new FlatProperties[fn.propertiesLength()];
 | 
					        FlatProperties[] flatProperties = new FlatProperties[fn.propertiesLength()];
 | 
				
			||||||
        for( int i=0; i<flatProperties.length; i++ ){
 | 
					        for (int i = 0; i < flatProperties.length; i++) {
 | 
				
			||||||
            flatProperties[i] = fn.properties(i);
 | 
					            flatProperties[i] = fn.properties(i);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        Map<String,Object> props = FlatBuffersMapper.mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties));
 | 
					        Map<String, Object> props = FlatBuffersMapper
 | 
				
			||||||
 | 
					                .mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if (opType == Op.Type.CUSTOM || opType == Type.LOGIC) {
 | 
				
			||||||
        if(opType == Op.Type.CUSTOM) {
 | 
					 | 
				
			||||||
            String opName = fn.opName();
 | 
					            String opName = fn.opName();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            DifferentialFunction op;
 | 
				
			||||||
            Class<?> c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName);
 | 
					            Class<?> c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum);
 | 
					            Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            DifferentialFunction op;
 | 
					 | 
				
			||||||
            try {
 | 
					            try {
 | 
				
			||||||
                op = (DifferentialFunction) c.newInstance();
 | 
					                op = (DifferentialFunction) c.newInstance();
 | 
				
			||||||
            } catch (IllegalAccessException | InstantiationException e) {
 | 
					            } catch (IllegalAccessException | InstantiationException e) {
 | 
				
			||||||
                throw new RuntimeException("Error creating differential function instance of type " + c);
 | 
					                throw new RuntimeException("Error creating differential function instance of type " + c);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            op.setOwnName(name);
 | 
					            op.setOwnName(name);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            //Set input SDVariables:
 | 
					            //Set input SDVariables:
 | 
				
			||||||
@ -390,7 +395,7 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
            op.setPropertiesForFunction(props);
 | 
					            op.setPropertiesForFunction(props);
 | 
				
			||||||
            return op;
 | 
					            return op;
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
            Class<?> c = LegacyOpMapper.getLegacyOpClassForId(opType, (int)opNum);
 | 
					            Class<?> c = LegacyOpMapper.getLegacyOpClassForId(opType, (int) opNum);
 | 
				
			||||||
            Op op;
 | 
					            Op op;
 | 
				
			||||||
            try {
 | 
					            try {
 | 
				
			||||||
                op = (Op) c.newInstance();
 | 
					                op = (Op) c.newInstance();
 | 
				
			||||||
@ -398,7 +403,7 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
                throw new RuntimeException("Error creating differential function (Op) instance of type " + c);
 | 
					                throw new RuntimeException("Error creating differential function (Op) instance of type " + c);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if(extraParams.length > 0) {
 | 
					            if (extraParams.length > 0) {
 | 
				
			||||||
                //Assume that extraParams length 0 means extraArgs was originally null, NOT originally length 0
 | 
					                //Assume that extraParams length 0 means extraArgs was originally null, NOT originally length 0
 | 
				
			||||||
                Object[] extraParamsObj = new Object[extraParams.length];
 | 
					                Object[] extraParamsObj = new Object[extraParams.length];
 | 
				
			||||||
                for (int i = 0; i < extraParams.length; i++) {
 | 
					                for (int i = 0; i < extraParams.length; i++) {
 | 
				
			||||||
@ -406,16 +411,18 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
                }
 | 
					                }
 | 
				
			||||||
                op.setExtraArgs(extraParamsObj);
 | 
					                op.setExtraArgs(extraParamsObj);
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            if(opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL){
 | 
					            if (opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL) {
 | 
				
			||||||
                ScalarOp sOp = (ScalarOp)op;
 | 
					                ScalarOp sOp = (ScalarOp) op;
 | 
				
			||||||
                sOp.setScalar(scalar);
 | 
					                sOp.setScalar(scalar);
 | 
				
			||||||
            } else if(opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS || opType == Op.Type.VARIANCE
 | 
					            } else if (opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS
 | 
				
			||||||
                    || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG || opType == Op.Type.REDUCE_SAME) {
 | 
					                    || opType == Op.Type.VARIANCE
 | 
				
			||||||
 | 
					                    || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG
 | 
				
			||||||
 | 
					                    || opType == Op.Type.REDUCE_SAME) {
 | 
				
			||||||
                val ba = (BaseReduceOp) op; //Reduce3 ops are also all BaseAccumulations
 | 
					                val ba = (BaseReduceOp) op; //Reduce3 ops are also all BaseAccumulations
 | 
				
			||||||
                ba.setDimensions(dimensions);
 | 
					                ba.setDimensions(dimensions);
 | 
				
			||||||
                ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
 | 
					                ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
 | 
				
			||||||
            } else if(opType == Op.Type.INDEXREDUCE){
 | 
					            } else if (opType == Op.Type.INDEXREDUCE) {
 | 
				
			||||||
                BaseIndexAccumulation bia = (BaseIndexAccumulation)op;
 | 
					                BaseIndexAccumulation bia = (BaseIndexAccumulation) op;
 | 
				
			||||||
                bia.setDimensions(dimensions);
 | 
					                bia.setDimensions(dimensions);
 | 
				
			||||||
                bia.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
 | 
					                bia.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
@ -428,8 +435,8 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
            TRANSFORM_SAME - Abs, Ceil, etc
 | 
					            TRANSFORM_SAME - Abs, Ceil, etc
 | 
				
			||||||
             */
 | 
					             */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ((DifferentialFunction)op).setPropertiesForFunction(props);
 | 
					            ((DifferentialFunction) op).setPropertiesForFunction(props);
 | 
				
			||||||
            return (DifferentialFunction)op;
 | 
					            return (DifferentialFunction) op;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -438,11 +445,11 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
    private static final long[] EMPTY_LONG = new long[0];
 | 
					    private static final long[] EMPTY_LONG = new long[0];
 | 
				
			||||||
    private static final double[] EMPTY_DOUBLE = new double[0];
 | 
					    private static final double[] EMPTY_DOUBLE = new double[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map<String,Object> fnProps){
 | 
					    public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map<String, Object> fnProps) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        int[] outIdxs = new int[fnProps.size()];
 | 
					        int[] outIdxs = new int[fnProps.size()];
 | 
				
			||||||
        int count = 0;
 | 
					        int count = 0;
 | 
				
			||||||
        for(Map.Entry<String,Object> e : fnProps.entrySet()){
 | 
					        for (Map.Entry<String, Object> e : fnProps.entrySet()) {
 | 
				
			||||||
            //Possible types here: primitives (as Number objects), primitive arrays, Strings, String arrays, multi-dimensional string/primitives
 | 
					            //Possible types here: primitives (as Number objects), primitive arrays, Strings, String arrays, multi-dimensional string/primitives
 | 
				
			||||||
            Object v = e.getValue();
 | 
					            Object v = e.getValue();
 | 
				
			||||||
            int iname = fbb.createString(e.getKey());
 | 
					            int iname = fbb.createString(e.getKey());
 | 
				
			||||||
@ -455,13 +462,11 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
            int[] sIdx = null;
 | 
					            int[] sIdx = null;
 | 
				
			||||||
            int[] shape = null;
 | 
					            int[] shape = null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (v == null) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
            if(v == null) {
 | 
					 | 
				
			||||||
                //No op
 | 
					                //No op
 | 
				
			||||||
            } else if(v instanceof Boolean){
 | 
					            } else if (v instanceof Boolean) {
 | 
				
			||||||
                b = new boolean[]{(Boolean)v};
 | 
					                b = new boolean[]{(Boolean) v};
 | 
				
			||||||
            } else if(v instanceof Number) {
 | 
					            } else if (v instanceof Number) {
 | 
				
			||||||
                if (v instanceof Double) {
 | 
					                if (v instanceof Double) {
 | 
				
			||||||
                    d = new double[]{(Double) v};
 | 
					                    d = new double[]{(Double) v};
 | 
				
			||||||
                } else if (v instanceof Integer) {
 | 
					                } else if (v instanceof Integer) {
 | 
				
			||||||
@ -469,39 +474,41 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
                } else if (v instanceof Long) {
 | 
					                } else if (v instanceof Long) {
 | 
				
			||||||
                    l = new long[]{(Long) v};
 | 
					                    l = new long[]{(Long) v};
 | 
				
			||||||
                } else {
 | 
					                } else {
 | 
				
			||||||
                    throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
 | 
					                    throw new UnsupportedOperationException(
 | 
				
			||||||
 | 
					                            "Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            } else if(v instanceof String) {
 | 
					            } else if (v instanceof String) {
 | 
				
			||||||
                String str = (String) v;
 | 
					                String str = (String) v;
 | 
				
			||||||
                int strOffset = fbb.createString(str);
 | 
					                int strOffset = fbb.createString(str);
 | 
				
			||||||
                sIdx = new int[]{strOffset};
 | 
					                sIdx = new int[]{strOffset};
 | 
				
			||||||
            } else if(v instanceof org.nd4j.linalg.api.buffer.DataType ) {
 | 
					            } else if (v instanceof org.nd4j.linalg.api.buffer.DataType) {
 | 
				
			||||||
                String str = v.toString();
 | 
					                String str = v.toString();
 | 
				
			||||||
                int strOffset = fbb.createString(str);
 | 
					                int strOffset = fbb.createString(str);
 | 
				
			||||||
                sIdx = new int[]{strOffset};
 | 
					                sIdx = new int[]{strOffset};
 | 
				
			||||||
            } else if(v instanceof Enum){
 | 
					            } else if (v instanceof Enum) {
 | 
				
			||||||
                String str = v.toString();
 | 
					                String str = v.toString();
 | 
				
			||||||
                int strOffset = fbb.createString(str);
 | 
					                int strOffset = fbb.createString(str);
 | 
				
			||||||
                sIdx = new int[]{strOffset};
 | 
					                sIdx = new int[]{strOffset};
 | 
				
			||||||
            } else if(v instanceof INDArray){
 | 
					            } else if (v instanceof INDArray) {
 | 
				
			||||||
                INDArray arr = (INDArray)v;
 | 
					                INDArray arr = (INDArray) v;
 | 
				
			||||||
                aIdx = new int[]{arr.toFlatArray(fbb)};
 | 
					                aIdx = new int[]{arr.toFlatArray(fbb)};
 | 
				
			||||||
            } else if(v.getClass().isArray()){
 | 
					            } else if (v.getClass().isArray()) {
 | 
				
			||||||
                if(v.getClass().getComponentType().isPrimitive()){
 | 
					                if (v.getClass().getComponentType().isPrimitive()) {
 | 
				
			||||||
                    if(v instanceof boolean[]) {
 | 
					                    if (v instanceof boolean[]) {
 | 
				
			||||||
                        b = (boolean[])v;
 | 
					                        b = (boolean[]) v;
 | 
				
			||||||
                        shape = new int[]{b.length};
 | 
					                        shape = new int[]{b.length};
 | 
				
			||||||
                    } else if(v instanceof double[]){
 | 
					                    } else if (v instanceof double[]) {
 | 
				
			||||||
                        d = (double[])v;
 | 
					                        d = (double[]) v;
 | 
				
			||||||
                        shape = new int[]{d.length};
 | 
					                        shape = new int[]{d.length};
 | 
				
			||||||
                    } else if(v instanceof int[]){
 | 
					                    } else if (v instanceof int[]) {
 | 
				
			||||||
                        i = (int[])v;
 | 
					                        i = (int[]) v;
 | 
				
			||||||
                        shape = new int[]{i.length};
 | 
					                        shape = new int[]{i.length};
 | 
				
			||||||
                    } else if(v instanceof long[]){
 | 
					                    } else if (v instanceof long[]) {
 | 
				
			||||||
                        l = (long[])v;
 | 
					                        l = (long[]) v;
 | 
				
			||||||
                        shape = new int[]{l.length};
 | 
					                        shape = new int[]{l.length};
 | 
				
			||||||
                    } else {
 | 
					                    } else {
 | 
				
			||||||
                        throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
 | 
					                        throw new UnsupportedOperationException(
 | 
				
			||||||
 | 
					                                "Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if (v instanceof String[]) {
 | 
					                } else if (v instanceof String[]) {
 | 
				
			||||||
                    //String[]
 | 
					                    //String[]
 | 
				
			||||||
@ -511,33 +518,35 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
                        sIdx[j] = fbb.createString(strArr[j]);
 | 
					                        sIdx[j] = fbb.createString(strArr[j]);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    shape = new int[]{strArr.length};
 | 
					                    shape = new int[]{strArr.length};
 | 
				
			||||||
                } else if (v instanceof INDArray[]){
 | 
					                } else if (v instanceof INDArray[]) {
 | 
				
			||||||
                    INDArray[] arrArr = (INDArray[])v;
 | 
					                    INDArray[] arrArr = (INDArray[]) v;
 | 
				
			||||||
                    aIdx = new int[arrArr.length];
 | 
					                    aIdx = new int[arrArr.length];
 | 
				
			||||||
                    for( int j=0; j<arrArr.length; j++){
 | 
					                    for (int j = 0; j < arrArr.length; j++) {
 | 
				
			||||||
                        aIdx[j] = arrArr[j].toFlatArray(fbb);
 | 
					                        aIdx[j] = arrArr[j].toFlatArray(fbb);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if(v.getClass().getComponentType().isArray()){
 | 
					                } else if (v.getClass().getComponentType().isArray()) {
 | 
				
			||||||
                    shape = ArrayUtil.arrayShape(v, true);
 | 
					                    shape = ArrayUtil.arrayShape(v, true);
 | 
				
			||||||
                    //Multi-dimensional array
 | 
					                    //Multi-dimensional array
 | 
				
			||||||
                    if(v instanceof boolean[][]) {
 | 
					                    if (v instanceof boolean[][]) {
 | 
				
			||||||
                        b = ArrayUtil.flatten((boolean[][]) v);
 | 
					                        b = ArrayUtil.flatten((boolean[][]) v);
 | 
				
			||||||
                    } else if(v instanceof boolean[][][]){
 | 
					                    } else if (v instanceof boolean[][][]) {
 | 
				
			||||||
                        b = ArrayUtil.flatten((boolean[][][]) v);
 | 
					                        b = ArrayUtil.flatten((boolean[][][]) v);
 | 
				
			||||||
                    } else if(v instanceof double[][]){
 | 
					                    } else if (v instanceof double[][]) {
 | 
				
			||||||
                        d = ArrayUtil.flatten((double[][]) v);
 | 
					                        d = ArrayUtil.flatten((double[][]) v);
 | 
				
			||||||
                    } else if(v instanceof double[][][]){
 | 
					                    } else if (v instanceof double[][][]) {
 | 
				
			||||||
                        d = ArrayUtil.flatten((double[][][]) v);
 | 
					                        d = ArrayUtil.flatten((double[][][]) v);
 | 
				
			||||||
                    } else if(v instanceof int[][]){
 | 
					                    } else if (v instanceof int[][]) {
 | 
				
			||||||
                        i = ArrayUtil.flatten((int[][])v);
 | 
					                        i = ArrayUtil.flatten((int[][]) v);
 | 
				
			||||||
                    } else if(v instanceof int[][][]){
 | 
					                    } else if (v instanceof int[][][]) {
 | 
				
			||||||
                        i = ArrayUtil.flatten((int[][][])v);
 | 
					                        i = ArrayUtil.flatten((int[][][]) v);
 | 
				
			||||||
                    } else if(v instanceof long[][]){
 | 
					                    } else if (v instanceof long[][]) {
 | 
				
			||||||
                        l = ArrayUtil.flatten((long[][])v);
 | 
					                        l = ArrayUtil.flatten((long[][]) v);
 | 
				
			||||||
                    } else if(v instanceof long[][][]){
 | 
					                    } else if (v instanceof long[][][]) {
 | 
				
			||||||
                        l = ArrayUtil.flatten((long[][][])v);
 | 
					                        l = ArrayUtil.flatten((long[][][]) v);
 | 
				
			||||||
                    } else {
 | 
					                    } else {
 | 
				
			||||||
                        throw new UnsupportedOperationException("Unable to map multidimensional array property \"" + e.getKey() + "\" of type " + v.getClass());
 | 
					                        throw new UnsupportedOperationException(
 | 
				
			||||||
 | 
					                                "Unable to map multidimensional array property \"" + e.getKey() + "\" of type " + v
 | 
				
			||||||
 | 
					                                        .getClass());
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
@ -550,21 +559,22 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
            int idxS = FlatProperties.createSVector(fbb, sIdx != null ? sIdx : EMPTY_INT);
 | 
					            int idxS = FlatProperties.createSVector(fbb, sIdx != null ? sIdx : EMPTY_INT);
 | 
				
			||||||
            int idxShape = FlatProperties.createShapeVector(fbb, shape != null ? shape : EMPTY_INT);
 | 
					            int idxShape = FlatProperties.createShapeVector(fbb, shape != null ? shape : EMPTY_INT);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            outIdxs[count++] = FlatProperties.createFlatProperties(fbb, iname, idxI, idxL, idxD, idxA, idxB, idxS, idxShape);
 | 
					            outIdxs[count++] = FlatProperties
 | 
				
			||||||
 | 
					                    .createFlatProperties(fbb, iname, idxI, idxL, idxD, idxA, idxB, idxS, idxShape);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        return outIdxs;
 | 
					        return outIdxs;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static Map<String,Object> mapFlatPropertiesToFunctionProperties(Iterable<FlatProperties> list){
 | 
					    public static Map<String, Object> mapFlatPropertiesToFunctionProperties(Iterable<FlatProperties> list) {
 | 
				
			||||||
        Map<String,Object> out = new HashMap<>();
 | 
					        Map<String, Object> out = new HashMap<>();
 | 
				
			||||||
        for(FlatProperties p : list){
 | 
					        for (FlatProperties p : list) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            String name = p.name();
 | 
					            String name = p.name();
 | 
				
			||||||
            //Work out type:
 | 
					            //Work out type:
 | 
				
			||||||
            if(p.shapeLength() > 0){
 | 
					            if (p.shapeLength() > 0) {
 | 
				
			||||||
                //Array type
 | 
					                //Array type
 | 
				
			||||||
                int[] shape = new int[p.shapeLength()];
 | 
					                int[] shape = new int[p.shapeLength()];
 | 
				
			||||||
                for( int i=0; i<shape.length; i++ ){
 | 
					                for (int i = 0; i < shape.length; i++) {
 | 
				
			||||||
                    shape[i] = p.shape(i);
 | 
					                    shape[i] = p.shape(i);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
//                if(shape.length != 1){
 | 
					//                if(shape.length != 1){
 | 
				
			||||||
@ -572,96 +582,96 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
//                    throw new IllegalStateException("Multi-dimensional arrays not yet implemented");
 | 
					//                    throw new IllegalStateException("Multi-dimensional arrays not yet implemented");
 | 
				
			||||||
//                }
 | 
					//                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if(p.iLength() > 0){
 | 
					                if (p.iLength() > 0) {
 | 
				
			||||||
                    int[] iArr = new int[p.iLength()];
 | 
					                    int[] iArr = new int[p.iLength()];
 | 
				
			||||||
                    for( int i=0; i<iArr.length; i++ ){
 | 
					                    for (int i = 0; i < iArr.length; i++) {
 | 
				
			||||||
                        iArr[i] = p.i(i);
 | 
					                        iArr[i] = p.i(i);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    if(shape.length == 0 || shape.length == 1) {
 | 
					                    if (shape.length == 0 || shape.length == 1) {
 | 
				
			||||||
                        out.put(name, iArr);
 | 
					                        out.put(name, iArr);
 | 
				
			||||||
                    } else if(shape.length == 2){
 | 
					                    } else if (shape.length == 2) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeInt(iArr, shape[0], shape[1]));
 | 
					                        out.put(name, ArrayUtil.reshapeInt(iArr, shape[0], shape[1]));
 | 
				
			||||||
                    } else if(shape.length == 3){
 | 
					                    } else if (shape.length == 3) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeInt(iArr, shape[0], shape[1], shape[2]));
 | 
					                        out.put(name, ArrayUtil.reshapeInt(iArr, shape[0], shape[1], shape[2]));
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if(p.dLength() > 0){
 | 
					                } else if (p.dLength() > 0) {
 | 
				
			||||||
                    double[] dArr = new double[p.dLength()];
 | 
					                    double[] dArr = new double[p.dLength()];
 | 
				
			||||||
                    for( int i=0; i<dArr.length; i++ ){
 | 
					                    for (int i = 0; i < dArr.length; i++) {
 | 
				
			||||||
                        dArr[i] = p.d(i);
 | 
					                        dArr[i] = p.d(i);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    if(shape.length == 0 || shape.length == 1) {
 | 
					                    if (shape.length == 0 || shape.length == 1) {
 | 
				
			||||||
                        out.put(name, dArr);
 | 
					                        out.put(name, dArr);
 | 
				
			||||||
                    } else if(shape.length == 2){
 | 
					                    } else if (shape.length == 2) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeDouble(dArr, shape[0], shape[1]));
 | 
					                        out.put(name, ArrayUtil.reshapeDouble(dArr, shape[0], shape[1]));
 | 
				
			||||||
                    } else if(shape.length == 3){
 | 
					                    } else if (shape.length == 3) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeDouble(dArr, shape[0], shape[1], shape[2]));
 | 
					                        out.put(name, ArrayUtil.reshapeDouble(dArr, shape[0], shape[1], shape[2]));
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if(p.lLength() > 0) {
 | 
					                } else if (p.lLength() > 0) {
 | 
				
			||||||
                    long[] lArr = new long[p.lLength()];
 | 
					                    long[] lArr = new long[p.lLength()];
 | 
				
			||||||
                    for (int i = 0; i < lArr.length; i++) {
 | 
					                    for (int i = 0; i < lArr.length; i++) {
 | 
				
			||||||
                        lArr[i] = p.l(i);
 | 
					                        lArr[i] = p.l(i);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    if(shape.length == 0 || shape.length == 1) {
 | 
					                    if (shape.length == 0 || shape.length == 1) {
 | 
				
			||||||
                        out.put(name, lArr);
 | 
					                        out.put(name, lArr);
 | 
				
			||||||
                    } else if(shape.length == 2){
 | 
					                    } else if (shape.length == 2) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1]));
 | 
					                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1]));
 | 
				
			||||||
                    } else if(shape.length == 3){
 | 
					                    } else if (shape.length == 3) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1], shape[2]));
 | 
					                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1], shape[2]));
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if(p.bLength() > 0){
 | 
					                } else if (p.bLength() > 0) {
 | 
				
			||||||
                    boolean[] bArr = new boolean[p.bLength()];
 | 
					                    boolean[] bArr = new boolean[p.bLength()];
 | 
				
			||||||
                    for( int i=0; i<bArr.length; i++ ){
 | 
					                    for (int i = 0; i < bArr.length; i++) {
 | 
				
			||||||
                        bArr[i] = p.b(i);
 | 
					                        bArr[i] = p.b(i);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    if(shape.length == 0 || shape.length == 1) {
 | 
					                    if (shape.length == 0 || shape.length == 1) {
 | 
				
			||||||
                        out.put(name, bArr);
 | 
					                        out.put(name, bArr);
 | 
				
			||||||
                    } else if(shape.length == 2){
 | 
					                    } else if (shape.length == 2) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeBoolean(bArr, shape[0], shape[1]));
 | 
					                        out.put(name, ArrayUtil.reshapeBoolean(bArr, shape[0], shape[1]));
 | 
				
			||||||
                    } else if(shape.length == 3){
 | 
					                    } else if (shape.length == 3) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeBoolean(bArr, shape[0], shape[1], shape[2]));
 | 
					                        out.put(name, ArrayUtil.reshapeBoolean(bArr, shape[0], shape[1], shape[2]));
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if(p.sLength() > 0){
 | 
					                } else if (p.sLength() > 0) {
 | 
				
			||||||
                    String[] sArr = new String[p.sLength()];
 | 
					                    String[] sArr = new String[p.sLength()];
 | 
				
			||||||
                    for( int i=0; i<sArr.length; i++ ){
 | 
					                    for (int i = 0; i < sArr.length; i++) {
 | 
				
			||||||
                        sArr[i] = p.s(i);
 | 
					                        sArr[i] = p.s(i);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    if(shape.length == 0 || shape.length == 1) {
 | 
					                    if (shape.length == 0 || shape.length == 1) {
 | 
				
			||||||
                        out.put(name, sArr);
 | 
					                        out.put(name, sArr);
 | 
				
			||||||
                    } else if(shape.length == 2){
 | 
					                    } else if (shape.length == 2) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeObject(sArr, shape[0], shape[1]));
 | 
					                        out.put(name, ArrayUtil.reshapeObject(sArr, shape[0], shape[1]));
 | 
				
			||||||
                    } else if(shape.length == 3){
 | 
					                    } else if (shape.length == 3) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeObject(sArr, shape[0], shape[1], shape[2]));
 | 
					                        out.put(name, ArrayUtil.reshapeObject(sArr, shape[0], shape[1], shape[2]));
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                } else if(p.aLength() > 0){
 | 
					                } else if (p.aLength() > 0) {
 | 
				
			||||||
                    INDArray[] iArr = new INDArray[p.aLength()];
 | 
					                    INDArray[] iArr = new INDArray[p.aLength()];
 | 
				
			||||||
                    for( int i=0; i<iArr.length; i++ ){
 | 
					                    for (int i = 0; i < iArr.length; i++) {
 | 
				
			||||||
                        FlatArray fa = p.a(0);
 | 
					                        FlatArray fa = p.a(0);
 | 
				
			||||||
                        iArr[i] = Nd4j.createFromFlatArray(fa);
 | 
					                        iArr[i] = Nd4j.createFromFlatArray(fa);
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                    if(shape.length == 0 || shape.length == 1) {
 | 
					                    if (shape.length == 0 || shape.length == 1) {
 | 
				
			||||||
                        out.put(name, iArr);
 | 
					                        out.put(name, iArr);
 | 
				
			||||||
                    } else if(shape.length == 2){
 | 
					                    } else if (shape.length == 2) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeObject(iArr, shape[0], shape[1]));
 | 
					                        out.put(name, ArrayUtil.reshapeObject(iArr, shape[0], shape[1]));
 | 
				
			||||||
                    } else if(shape.length == 3){
 | 
					                    } else if (shape.length == 3) {
 | 
				
			||||||
                        out.put(name, ArrayUtil.reshapeObject(iArr, shape[0], shape[1], shape[2]));
 | 
					                        out.put(name, ArrayUtil.reshapeObject(iArr, shape[0], shape[1], shape[2]));
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }  else {
 | 
					                } else {
 | 
				
			||||||
                    //null property case
 | 
					                    //null property case
 | 
				
			||||||
                    out.put(name, null);
 | 
					                    out.put(name, null);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            } else {
 | 
					            } else {
 | 
				
			||||||
                //non-array primitive, String or INDArray
 | 
					                //non-array primitive, String or INDArray
 | 
				
			||||||
                if(p.bLength() > 0) {
 | 
					                if (p.bLength() > 0) {
 | 
				
			||||||
                    out.put(name, p.b(0));
 | 
					                    out.put(name, p.b(0));
 | 
				
			||||||
                } else if(p.iLength() > 0){
 | 
					                } else if (p.iLength() > 0) {
 | 
				
			||||||
                    out.put(name, p.i(0));
 | 
					                    out.put(name, p.i(0));
 | 
				
			||||||
                } else if(p.lLength() > 0){
 | 
					                } else if (p.lLength() > 0) {
 | 
				
			||||||
                    out.put(name, p.l(0));
 | 
					                    out.put(name, p.l(0));
 | 
				
			||||||
                } else if(p.dLength() > 0){
 | 
					                } else if (p.dLength() > 0) {
 | 
				
			||||||
                    out.put(name, p.d(0));
 | 
					                    out.put(name, p.d(0));
 | 
				
			||||||
                } else if(p.sLength() > 0){
 | 
					                } else if (p.sLength() > 0) {
 | 
				
			||||||
                    out.put(name, p.s(0));
 | 
					                    out.put(name, p.s(0));
 | 
				
			||||||
                } else if(p.aLength() > 0){
 | 
					                } else if (p.aLength() > 0) {
 | 
				
			||||||
                    FlatArray fa = p.a(0);
 | 
					                    FlatArray fa = p.a(0);
 | 
				
			||||||
                    out.put(name, Nd4j.createFromFlatArray(fa));
 | 
					                    out.put(name, Nd4j.createFromFlatArray(fa));
 | 
				
			||||||
                } else {
 | 
					                } else {
 | 
				
			||||||
@ -673,8 +683,8 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
        return out;
 | 
					        return out;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static byte toVarType(VariableType variableType){
 | 
					    public static byte toVarType(VariableType variableType) {
 | 
				
			||||||
        switch (variableType){
 | 
					        switch (variableType) {
 | 
				
			||||||
            case VARIABLE:
 | 
					            case VARIABLE:
 | 
				
			||||||
                return VarType.VARIABLE;
 | 
					                return VarType.VARIABLE;
 | 
				
			||||||
            case CONSTANT:
 | 
					            case CONSTANT:
 | 
				
			||||||
@ -688,8 +698,8 @@ public class FlatBuffersMapper {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static VariableType fromVarType(byte varType){
 | 
					    public static VariableType fromVarType(byte varType) {
 | 
				
			||||||
        switch (varType){
 | 
					        switch (varType) {
 | 
				
			||||||
            case VarType.VARIABLE:
 | 
					            case VarType.VARIABLE:
 | 
				
			||||||
                return VariableType.VARIABLE;
 | 
					                return VariableType.VARIABLE;
 | 
				
			||||||
            case VarType.CONSTANT:
 | 
					            case VarType.CONSTANT:
 | 
				
			||||||
 | 
				
			|||||||
@ -126,12 +126,7 @@ public class LegacyOpMapper {
 | 
				
			|||||||
            case CONDITIONAL:
 | 
					            case CONDITIONAL:
 | 
				
			||||||
            case LOOP:
 | 
					            case LOOP:
 | 
				
			||||||
            case LOOP_COND:
 | 
					            case LOOP_COND:
 | 
				
			||||||
            case IF:
 | 
					 | 
				
			||||||
            case RETURN:
 | 
					            case RETURN:
 | 
				
			||||||
            case ENTER:
 | 
					 | 
				
			||||||
            case EXIT:
 | 
					 | 
				
			||||||
            case NEXT_ITERATION:
 | 
					 | 
				
			||||||
            case MERGE:
 | 
					 | 
				
			||||||
            default:
 | 
					            default:
 | 
				
			||||||
                throw new UnsupportedOperationException("Unable to map op " + opNum + " of type " + opType);
 | 
					                throw new UnsupportedOperationException("Unable to map op " + opNum + " of type " + opType);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
				
			|||||||
@ -25,6 +25,11 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
 | 
				
			|||||||
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
 | 
					import org.nd4j.imports.descriptors.onnx.OpDescriptor;
 | 
				
			||||||
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
 | 
					import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.*;
 | 
					import org.nd4j.linalg.api.ops.*;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
 | 
					import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
 | 
				
			||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
 | 
					import org.nd4j.linalg.exception.ND4JIllegalStateException;
 | 
				
			||||||
import org.nd4j.linalg.factory.Nd4j;
 | 
					import org.nd4j.linalg.factory.Nd4j;
 | 
				
			||||||
@ -331,13 +336,27 @@ public class DifferentialFunctionClassHolder {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public Class<?> customOpClassForHashAndName(long customOpHash, String name){
 | 
					    public Class<?> customOpClassForHashAndName(long customOpHash, String name){
 | 
				
			||||||
        if(customOpHashToClasses.containsKey(customOpHash)){
 | 
					        switch (name) {
 | 
				
			||||||
            return customOpHashToClasses.get(customOpHash).get(name);
 | 
					            case Enter.OP_NAME:
 | 
				
			||||||
        } else if(customOpHashToClass.containsKey(customOpHash)){
 | 
					                return Enter.class;
 | 
				
			||||||
            return customOpHashToClass.get(customOpHash);
 | 
					            case Exit.OP_NAME:
 | 
				
			||||||
        } else {
 | 
					                return Exit.class;
 | 
				
			||||||
            throw new IllegalStateException("No op known for hash: " + customOpHash);
 | 
					            case NextIteration.OP_NAME:
 | 
				
			||||||
 | 
					                return NextIteration.class;
 | 
				
			||||||
 | 
					            case Merge.OP_NAME:
 | 
				
			||||||
 | 
					                return Merge.class;
 | 
				
			||||||
 | 
					            case Switch.OP_NAME:
 | 
				
			||||||
 | 
					                return Switch.class;
 | 
				
			||||||
 | 
					            default:
 | 
				
			||||||
 | 
					                if(customOpHashToClasses.containsKey(customOpHash)){
 | 
				
			||||||
 | 
					                    return customOpHashToClasses.get(customOpHash).get(name);
 | 
				
			||||||
 | 
					                } else if(customOpHashToClass.containsKey(customOpHash)){
 | 
				
			||||||
 | 
					                    return customOpHashToClass.get(customOpHash);
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    throw new IllegalStateException("No op known for hash: " + customOpHash);
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public static DifferentialFunctionClassHolder getInstance() {
 | 
					    public static DifferentialFunctionClassHolder getInstance() {
 | 
				
			||||||
 | 
				
			|||||||
@ -69,14 +69,10 @@ public interface Op {
 | 
				
			|||||||
        CONDITIONAL,
 | 
					        CONDITIONAL,
 | 
				
			||||||
        LOOP,
 | 
					        LOOP,
 | 
				
			||||||
        LOOP_COND,
 | 
					        LOOP_COND,
 | 
				
			||||||
        IF,
 | 
					 | 
				
			||||||
        RETURN,
 | 
					        RETURN,
 | 
				
			||||||
        ENTER,
 | 
					 | 
				
			||||||
        EXIT,
 | 
					 | 
				
			||||||
        NEXT_ITERATION,
 | 
					 | 
				
			||||||
        RANDOM,
 | 
					        RANDOM,
 | 
				
			||||||
        MERGE,
 | 
					 | 
				
			||||||
        SUMMARYSTATS,
 | 
					        SUMMARYSTATS,
 | 
				
			||||||
 | 
					        LOGIC
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
 | 
				
			|||||||
@ -17,11 +17,13 @@
 | 
				
			|||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
					package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import lombok.Data;
 | 
					import lombok.Data;
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
					import org.nd4j.autodiff.samediff.SDVariable;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
					import org.nd4j.autodiff.samediff.SameDiff;
 | 
				
			||||||
import org.nd4j.base.Preconditions;
 | 
					import org.nd4j.base.Preconditions;
 | 
				
			||||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
					import org.nd4j.linalg.api.buffer.DataType;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.Op;
 | 
					import org.nd4j.linalg.api.ops.Op;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op.Type;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
					import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
				
			||||||
import org.tensorflow.framework.AttrValue;
 | 
					import org.tensorflow.framework.AttrValue;
 | 
				
			||||||
import org.tensorflow.framework.GraphDef;
 | 
					import org.tensorflow.framework.GraphDef;
 | 
				
			||||||
@ -32,13 +34,38 @@ import java.util.List;
 | 
				
			|||||||
import java.util.Map;
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@Data
 | 
					@Data
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
public class Enter extends BaseCompatOp {
 | 
					public class Enter extends BaseCompatOp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    protected boolean isConstant;
 | 
					    protected boolean isConstant;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public Enter(SameDiff sameDiff, SDVariable[] inputs){
 | 
				
			||||||
 | 
					        super(sameDiff, inputs);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public Enter(SameDiff sameDiff, String frameName, SDVariable input){
 | 
				
			||||||
 | 
					        super(sameDiff, new SDVariable[]{input});
 | 
				
			||||||
 | 
					        this.frameName = frameName;
 | 
				
			||||||
 | 
					        isConstant = input.isConstant();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public Enter(SameDiff sameDiff, String frameName, SDVariable input, boolean isConstant){
 | 
				
			||||||
 | 
					        super(sameDiff, new SDVariable[]{input});
 | 
				
			||||||
 | 
					        this.frameName = frameName;
 | 
				
			||||||
 | 
					        this.isConstant = isConstant;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * WARNING: do not change without changing serialization methods
 | 
				
			||||||
 | 
					     * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
 | 
				
			||||||
 | 
					     *  and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public static final String OP_NAME = "enter";
 | 
				
			||||||
 | 
					    public static final int OP_NUM = 100;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String opName() {
 | 
					    public String opName() {
 | 
				
			||||||
        return "enter";
 | 
					        return OP_NAME;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
@ -62,7 +89,7 @@ public class Enter extends BaseCompatOp {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public Op.Type opType() {
 | 
					    public Op.Type opType() {
 | 
				
			||||||
        return Op.Type.ENTER;
 | 
					        return Type.LOGIC;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
 | 
				
			|||||||
@ -16,6 +16,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
					package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
import lombok.NonNull;
 | 
					import lombok.NonNull;
 | 
				
			||||||
import lombok.val;
 | 
					import lombok.val;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
					import org.nd4j.autodiff.samediff.SDVariable;
 | 
				
			||||||
@ -24,6 +25,7 @@ import org.nd4j.base.Preconditions;
 | 
				
			|||||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
					import org.nd4j.linalg.api.buffer.DataType;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
					import org.nd4j.linalg.api.ops.DynamicCustomOp;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.Op;
 | 
					import org.nd4j.linalg.api.ops.Op;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op.Type;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
					import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
				
			||||||
import org.tensorflow.framework.AttrValue;
 | 
					import org.tensorflow.framework.AttrValue;
 | 
				
			||||||
import org.tensorflow.framework.GraphDef;
 | 
					import org.tensorflow.framework.GraphDef;
 | 
				
			||||||
@ -34,10 +36,24 @@ import java.util.Collections;
 | 
				
			|||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
import java.util.Map;
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
public class Exit extends BaseCompatOp {
 | 
					public class Exit extends BaseCompatOp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public Exit(SameDiff sameDiff, SDVariable x) {
 | 
				
			||||||
 | 
					        super(sameDiff, new SDVariable[]{x});
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * WARNING: do not change without changing serialization methods
 | 
				
			||||||
 | 
					     * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
 | 
				
			||||||
 | 
					     *  and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public static final String OP_NAME = "exit";
 | 
				
			||||||
 | 
					    public static final int OP_NUM = 90;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String opName() {
 | 
					    public String opName() {
 | 
				
			||||||
        return "exit";
 | 
					        return OP_NAME;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
@ -61,7 +77,7 @@ public class Exit extends BaseCompatOp {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public Op.Type opType() {
 | 
					    public Op.Type opType() {
 | 
				
			||||||
        return Op.Type.EXIT;
 | 
					        return Type.LOGIC;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
 | 
				
			|||||||
@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 | 
				
			|||||||
import org.nd4j.base.Preconditions;
 | 
					import org.nd4j.base.Preconditions;
 | 
				
			||||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
					import org.nd4j.linalg.api.buffer.DataType;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.Op;
 | 
					import org.nd4j.linalg.api.ops.Op;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op.Type;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
					import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
				
			||||||
import org.tensorflow.framework.AttrValue;
 | 
					import org.tensorflow.framework.AttrValue;
 | 
				
			||||||
import org.tensorflow.framework.GraphDef;
 | 
					import org.tensorflow.framework.GraphDef;
 | 
				
			||||||
@ -41,9 +42,21 @@ public class Merge extends BaseCompatOp {
 | 
				
			|||||||
        
 | 
					        
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * WARNING: do not change without changing serialization methods
 | 
				
			||||||
 | 
					     * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
 | 
				
			||||||
 | 
					     *  and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public static final String OP_NAME = "merge";
 | 
				
			||||||
 | 
					    public static final int OP_NUM = 60;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public Merge(SameDiff sd, SDVariable a, SDVariable b){
 | 
				
			||||||
 | 
					        this(sd, new SDVariable[]{a, b});
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String opName() {
 | 
					    public String opName() {
 | 
				
			||||||
        return "merge";
 | 
					        return OP_NAME;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
@ -72,7 +85,7 @@ public class Merge extends BaseCompatOp {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public Op.Type opType() {
 | 
					    public Op.Type opType() {
 | 
				
			||||||
        return Op.Type.MERGE;
 | 
					        return Type.LOGIC;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
 | 
				
			|||||||
@ -16,11 +16,13 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
					package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import lombok.NoArgsConstructor;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
					import org.nd4j.autodiff.samediff.SDVariable;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
					import org.nd4j.autodiff.samediff.SameDiff;
 | 
				
			||||||
import org.nd4j.base.Preconditions;
 | 
					import org.nd4j.base.Preconditions;
 | 
				
			||||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
					import org.nd4j.linalg.api.buffer.DataType;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.Op;
 | 
					import org.nd4j.linalg.api.ops.Op;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op.Type;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
					import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
				
			||||||
import org.tensorflow.framework.AttrValue;
 | 
					import org.tensorflow.framework.AttrValue;
 | 
				
			||||||
import org.tensorflow.framework.GraphDef;
 | 
					import org.tensorflow.framework.GraphDef;
 | 
				
			||||||
@ -31,10 +33,24 @@ import java.util.Collections;
 | 
				
			|||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
import java.util.Map;
 | 
					import java.util.Map;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@NoArgsConstructor
 | 
				
			||||||
public class NextIteration extends BaseCompatOp {
 | 
					public class NextIteration extends BaseCompatOp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public NextIteration(SameDiff sameDiff, SDVariable x) {
 | 
				
			||||||
 | 
					        super(sameDiff, new SDVariable[]{x});
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * WARNING: do not change without changing serialization methods
 | 
				
			||||||
 | 
					     * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
 | 
				
			||||||
 | 
					     *  and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public static final String OP_NAME = "next_iteration";
 | 
				
			||||||
 | 
					    public static final int OP_NUM = 80;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String opName() {
 | 
					    public String opName() {
 | 
				
			||||||
        return "next_iteration";
 | 
					        return OP_NAME;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
@ -58,7 +74,7 @@ public class NextIteration extends BaseCompatOp {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public Op.Type opType() {
 | 
					    public Op.Type opType() {
 | 
				
			||||||
        return Op.Type.NEXT_ITERATION;
 | 
					        return Type.LOGIC;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
 | 
				
			|||||||
@ -16,12 +16,15 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
					package org.nd4j.linalg.api.ops.impl.controlflow.compat;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.google.common.collect.Lists;
 | 
				
			||||||
 | 
					import lombok.Getter;
 | 
				
			||||||
import lombok.val;
 | 
					import lombok.val;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SDVariable;
 | 
					import org.nd4j.autodiff.samediff.SDVariable;
 | 
				
			||||||
import org.nd4j.autodiff.samediff.SameDiff;
 | 
					import org.nd4j.autodiff.samediff.SameDiff;
 | 
				
			||||||
import org.nd4j.base.Preconditions;
 | 
					import org.nd4j.base.Preconditions;
 | 
				
			||||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
					import org.nd4j.linalg.api.buffer.DataType;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.Op;
 | 
					import org.nd4j.linalg.api.ops.Op;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.Op.Type;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
					import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
				
			||||||
import org.tensorflow.framework.AttrValue;
 | 
					import org.tensorflow.framework.AttrValue;
 | 
				
			||||||
import org.tensorflow.framework.GraphDef;
 | 
					import org.tensorflow.framework.GraphDef;
 | 
				
			||||||
@ -37,15 +40,27 @@ import java.util.Map;
 | 
				
			|||||||
 */
 | 
					 */
 | 
				
			||||||
public class Switch extends BaseCompatOp {
 | 
					public class Switch extends BaseCompatOp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Getter
 | 
				
			||||||
 | 
					    private SDVariable predicate;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public Switch(SameDiff sameDiff, SDVariable input, SDVariable predicate){
 | 
					    public Switch(SameDiff sameDiff, SDVariable input, SDVariable predicate){
 | 
				
			||||||
        super(sameDiff, new SDVariable[]{input, predicate});
 | 
					        super(sameDiff, new SDVariable[]{input, predicate});
 | 
				
			||||||
 | 
					        this.predicate = predicate;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public Switch(){ }
 | 
					    public Switch(){ }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * WARNING: do not change without changing serialization methods
 | 
				
			||||||
 | 
					     * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
 | 
				
			||||||
 | 
					     *  and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    public static final String OP_NAME = "switch";
 | 
				
			||||||
 | 
					    public static final int OP_NUM = 30;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String opName() {
 | 
					    public String opName() {
 | 
				
			||||||
        return "switch";
 | 
					        return OP_NAME;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
@ -72,7 +87,7 @@ public class Switch extends BaseCompatOp {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public Op.Type opType() {
 | 
					    public Op.Type opType() {
 | 
				
			||||||
        return Op.Type.IF;
 | 
					        return Type.LOGIC;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
 | 
				
			|||||||
@ -39,6 +39,9 @@ import java.util.List;
 | 
				
			|||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public class LogSoftMax extends DynamicCustomOp {
 | 
					public class LogSoftMax extends DynamicCustomOp {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    private Integer dimension = null;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public LogSoftMax(SameDiff sameDiff, SDVariable i_v) {
 | 
					    public LogSoftMax(SameDiff sameDiff, SDVariable i_v) {
 | 
				
			||||||
        super(sameDiff, i_v);
 | 
					        super(sameDiff, i_v);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -54,6 +57,12 @@ public class LogSoftMax extends DynamicCustomOp {
 | 
				
			|||||||
        this(x, x);
 | 
					        this(x, x);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) {
 | 
				
			||||||
 | 
					        this(sameDiff, i_v);
 | 
				
			||||||
 | 
					        this.dimension = dimension;
 | 
				
			||||||
 | 
					        addIArgument(dimension);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public String opName() {
 | 
					    public String opName() {
 | 
				
			||||||
@ -66,8 +75,13 @@ public class LogSoftMax extends DynamicCustomOp {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
 | 
					    public List<SDVariable> doDiff(List<SDVariable> i_v) {
 | 
				
			||||||
        SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
 | 
					        if(dimension == null) {
 | 
				
			||||||
        return Collections.singletonList(ret);
 | 
					            SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
 | 
				
			||||||
 | 
					            return Collections.singletonList(ret);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0), dimension);
 | 
				
			||||||
 | 
					            return Collections.singletonList(ret);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
 | 
				
			|||||||
@ -43,6 +43,11 @@ public class LogSoftMaxDerivative extends DynamicCustomOp {
 | 
				
			|||||||
        super(null, new INDArray[]{in, gradO}, new INDArray[]{out});
 | 
					        super(null, new INDArray[]{in, gradO}, new INDArray[]{out});
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public LogSoftMaxDerivative(SameDiff sameDiff, SDVariable arg, SDVariable wrt, int dimension) {
 | 
				
			||||||
 | 
					        this(sameDiff, arg, wrt);
 | 
				
			||||||
 | 
					        this.addIArgument(dimension);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * The opName of this operation
 | 
					     * The opName of this operation
 | 
				
			||||||
     *
 | 
					     *
 | 
				
			||||||
 | 
				
			|||||||
@ -129,4 +129,39 @@ public class NameScopeTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void testNoNesting(){
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable a = SD.constant(4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope scope = SD.withNameScope("test");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable out = SD.argmax(a);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out.add(45);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1"));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void testNoTesting2(){
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable a = SD.constant(4);
 | 
				
			||||||
 | 
					        SDVariable b = SD.constant(5).lt(4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        NameScope scope = SD.withNameScope("test");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable out = SD.f().switchOp(a, b)[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out.add(45);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scope.close();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertTrue("Var with name test/switch:1 exists", SD.variableMap().containsKey("test/switch:1"));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -16,12 +16,30 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package org.nd4j.autodiff.samediff;
 | 
					package org.nd4j.autodiff.samediff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import static org.junit.Assert.assertEquals;
 | 
				
			||||||
 | 
					import static org.junit.Assert.assertNotEquals;
 | 
				
			||||||
 | 
					import static org.junit.Assert.assertNotNull;
 | 
				
			||||||
 | 
					import static org.junit.Assert.assertNull;
 | 
				
			||||||
 | 
					import static org.junit.Assert.assertTrue;
 | 
				
			||||||
 | 
					import static org.junit.Assert.fail;
 | 
				
			||||||
 | 
					import static org.junit.Assume.assumeNotNull;
 | 
				
			||||||
 | 
					import static org.nd4j.linalg.indexing.NDArrayIndex.all;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import com.google.common.collect.Lists;
 | 
					import com.google.common.collect.Lists;
 | 
				
			||||||
 | 
					import com.google.common.collect.Maps;
 | 
				
			||||||
 | 
					import java.io.IOException;
 | 
				
			||||||
 | 
					import java.lang.reflect.Field;
 | 
				
			||||||
 | 
					import java.util.Arrays;
 | 
				
			||||||
 | 
					import java.util.Collections;
 | 
				
			||||||
 | 
					import java.util.HashMap;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					import java.util.Map;
 | 
				
			||||||
import lombok.extern.slf4j.Slf4j;
 | 
					import lombok.extern.slf4j.Slf4j;
 | 
				
			||||||
import lombok.val;
 | 
					import lombok.val;
 | 
				
			||||||
import org.junit.After;
 | 
					import org.junit.After;
 | 
				
			||||||
import org.junit.Before;
 | 
					import org.junit.Before;
 | 
				
			||||||
import org.junit.ClassRule;
 | 
					import org.junit.ClassRule;
 | 
				
			||||||
 | 
					import org.junit.Ignore;
 | 
				
			||||||
import org.junit.Test;
 | 
					import org.junit.Test;
 | 
				
			||||||
import org.junit.rules.TemporaryFolder;
 | 
					import org.junit.rules.TemporaryFolder;
 | 
				
			||||||
import org.nd4j.OpValidationSuite;
 | 
					import org.nd4j.OpValidationSuite;
 | 
				
			||||||
@ -43,7 +61,11 @@ import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
 | 
				
			|||||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
 | 
				
			||||||
 | 
					import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
 | 
				
			||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 | 
					import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 | 
				
			||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
					import org.nd4j.linalg.api.shape.LongShapeDescriptor;
 | 
				
			||||||
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
 | 
					import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
 | 
				
			||||||
@ -53,9 +75,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
 | 
				
			|||||||
import org.nd4j.linalg.factory.Nd4j;
 | 
					import org.nd4j.linalg.factory.Nd4j;
 | 
				
			||||||
import org.nd4j.linalg.factory.Nd4jBackend;
 | 
					import org.nd4j.linalg.factory.Nd4jBackend;
 | 
				
			||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
 | 
					import org.nd4j.linalg.indexing.NDArrayIndex;
 | 
				
			||||||
import org.nd4j.linalg.learning.GradientUpdater;
 | 
					 | 
				
			||||||
import org.nd4j.linalg.learning.config.Adam;
 | 
					import org.nd4j.linalg.learning.config.Adam;
 | 
				
			||||||
import org.nd4j.linalg.learning.config.Nesterovs;
 | 
					 | 
				
			||||||
import org.nd4j.linalg.ops.transforms.Transforms;
 | 
					import org.nd4j.linalg.ops.transforms.Transforms;
 | 
				
			||||||
import org.nd4j.linalg.primitives.Pair;
 | 
					import org.nd4j.linalg.primitives.Pair;
 | 
				
			||||||
import org.nd4j.nativeblas.NativeOpsHolder;
 | 
					import org.nd4j.nativeblas.NativeOpsHolder;
 | 
				
			||||||
@ -63,29 +83,20 @@ import org.nd4j.weightinit.impl.OneInitScheme;
 | 
				
			|||||||
import org.nd4j.weightinit.impl.UniformInitScheme;
 | 
					import org.nd4j.weightinit.impl.UniformInitScheme;
 | 
				
			||||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
 | 
					import org.nd4j.weightinit.impl.ZeroInitScheme;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.io.BufferedOutputStream;
 | 
					 | 
				
			||||||
import java.io.File;
 | 
					 | 
				
			||||||
import java.io.FileOutputStream;
 | 
					 | 
				
			||||||
import java.lang.reflect.Field;
 | 
					 | 
				
			||||||
import java.util.*;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import static org.junit.Assert.*;
 | 
					 | 
				
			||||||
import static org.junit.Assume.assumeNotNull;
 | 
					 | 
				
			||||||
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Created by agibsonccc on 4/11/17.
 | 
					 * Created by agibsonccc on 4/11/17.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
@Slf4j
 | 
					@Slf4j
 | 
				
			||||||
public class SameDiffTests extends BaseNd4jTest {
 | 
					public class SameDiffTests extends BaseNd4jTest {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    private DataType initialType;
 | 
					    private DataType initialType;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    public SameDiffTests(Nd4jBackend b){
 | 
					    public SameDiffTests(Nd4jBackend b) {
 | 
				
			||||||
        super(b);
 | 
					        super(b);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Override
 | 
					    @Override
 | 
				
			||||||
    public char ordering(){
 | 
					    public char ordering() {
 | 
				
			||||||
        return 'c';
 | 
					        return 'c';
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -317,7 +328,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        SameDiff first = SameDiff.create();
 | 
					        SameDiff first = SameDiff.create();
 | 
				
			||||||
        SameDiff second = SameDiff.create();
 | 
					        SameDiff second = SameDiff.create();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable firstVar = first.var("one", new long[]{2, 2});
 | 
					        SDVariable firstVar = first.var("one", new long[]{2, 2});
 | 
				
			||||||
        SDVariable secondVar = second.var(firstVar);
 | 
					        SDVariable secondVar = second.var(firstVar);
 | 
				
			||||||
        assertTrue(firstVar.getArr() == secondVar.getArr());
 | 
					        assertTrue(firstVar.getArr() == secondVar.getArr());
 | 
				
			||||||
@ -330,7 +340,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        SameDiff first = SameDiff.create();
 | 
					        SameDiff first = SameDiff.create();
 | 
				
			||||||
        SameDiff second = SameDiff.create();
 | 
					        SameDiff second = SameDiff.create();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable firstVar = first.var("one", new long[]{2, 2});
 | 
					        SDVariable firstVar = first.var("one", new long[]{2, 2});
 | 
				
			||||||
        SDVariable secondVar = second.var(firstVar);
 | 
					        SDVariable secondVar = second.var(firstVar);
 | 
				
			||||||
        assumeNotNull(firstVar.getArr());
 | 
					        assumeNotNull(firstVar.getArr());
 | 
				
			||||||
@ -418,7 +427,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }, xAndY);
 | 
					        }, xAndY);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0);
 | 
					        INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0);
 | 
				
			||||||
        INDArray assertionForRDiv = Nd4j.valueArrayOf(4, 0.25);
 | 
					        INDArray assertionForRDiv = Nd4j.valueArrayOf(4, 0.25);
 | 
				
			||||||
        assertEquals(assertionForDiv, sameDiff.getFunction("div").execAndEndResult());
 | 
					        assertEquals(assertionForDiv, sameDiff.getFunction("div").execAndEndResult());
 | 
				
			||||||
@ -463,7 +471,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        }, inputs);
 | 
					        }, inputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray assertion = sumInput.sum(1);
 | 
					        INDArray assertion = sumInput.sum(1);
 | 
				
			||||||
        INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum")).get("sum");
 | 
					        INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum"))
 | 
				
			||||||
 | 
					                .get("sum");
 | 
				
			||||||
        assertEquals(assertion, out);
 | 
					        assertEquals(assertion, out);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -563,7 +572,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }, inputVars);
 | 
					        }, inputVars);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        //1 input plus 2 outputs
 | 
					        //1 input plus 2 outputs
 | 
				
			||||||
        assertEquals(3, functionDef.variables().size());
 | 
					        assertEquals(3, functionDef.variables().size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -573,7 +581,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testIfStatementTrueBodyBackwards() {
 | 
					    public void testIfStatementTrueBodyBackwards() {
 | 
				
			||||||
        OpValidationSuite.ignoreFailing();      //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
 | 
					        OpValidationSuite
 | 
				
			||||||
 | 
					                .ignoreFailing();      //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
 | 
				
			||||||
        SameDiff sameDiff = SameDiff.create();
 | 
					        SameDiff sameDiff = SameDiff.create();
 | 
				
			||||||
        SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
 | 
					        SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
 | 
				
			||||||
            @Override
 | 
					            @Override
 | 
				
			||||||
@ -584,7 +593,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
 | 
					        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
 | 
				
			||||||
            @Override
 | 
					            @Override
 | 
				
			||||||
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
 | 
					            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
 | 
				
			||||||
@ -607,7 +615,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs);
 | 
					        sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs);
 | 
				
			||||||
        sameDiff.execBackwards(Collections.emptyMap());
 | 
					        sameDiff.execBackwards(Collections.emptyMap());
 | 
				
			||||||
        SameDiff grad = sameDiff.getFunction("grad");
 | 
					        SameDiff grad = sameDiff.getFunction("grad");
 | 
				
			||||||
@ -625,7 +632,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testIfStatementTrueBody() {
 | 
					    public void testIfStatementTrueBody() {
 | 
				
			||||||
        OpValidationSuite.ignoreFailing();      //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
 | 
					        OpValidationSuite
 | 
				
			||||||
 | 
					                .ignoreFailing();      //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
 | 
				
			||||||
        SameDiff sameDiff = SameDiff.create();
 | 
					        SameDiff sameDiff = SameDiff.create();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
 | 
					        SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
 | 
				
			||||||
@ -637,7 +645,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
 | 
					        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
 | 
				
			||||||
            @Override
 | 
					            @Override
 | 
				
			||||||
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
 | 
					            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
 | 
				
			||||||
@ -660,7 +667,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs);
 | 
					        sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs);
 | 
				
			||||||
        sameDiff.exec(Collections.emptyMap());
 | 
					        sameDiff.exec(Collections.emptyMap());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -668,7 +674,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testIfStatementFalseBody() {
 | 
					    public void testIfStatementFalseBody() {
 | 
				
			||||||
        OpValidationSuite.ignoreFailing();      //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
 | 
					        OpValidationSuite
 | 
				
			||||||
 | 
					                .ignoreFailing();      //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
 | 
				
			||||||
        SameDiff sameDiff = SameDiff.create();
 | 
					        SameDiff sameDiff = SameDiff.create();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
 | 
					        SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
 | 
				
			||||||
@ -680,7 +687,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
 | 
					        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
 | 
				
			||||||
            @Override
 | 
					            @Override
 | 
				
			||||||
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
 | 
					            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
 | 
				
			||||||
@ -697,7 +703,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        //false body trigger
 | 
					        //false body trigger
 | 
				
			||||||
        SDVariable[] secondInputs = new SDVariable[]{
 | 
					        SDVariable[] secondInputs = new SDVariable[]{
 | 
				
			||||||
                sameDiff.setupFunction(sameDiff.var("two", new long[]{1, 1}))
 | 
					                sameDiff.setupFunction(sameDiff.var("two", new long[]{1, 1}))
 | 
				
			||||||
@ -790,7 +795,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        SDVariable weights = sd.var("W", new long[]{nIn, nOut});
 | 
					        SDVariable weights = sd.var("W", new long[]{nIn, nOut});
 | 
				
			||||||
        SDVariable bias = sd.var("b", new long[]{1, nOut});
 | 
					        SDVariable bias = sd.var("b", new long[]{1, nOut});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable mmul = sd.mmul("mmul", input, weights);
 | 
					        SDVariable mmul = sd.mmul("mmul", input, weights);
 | 
				
			||||||
        SDVariable z = mmul.add("z", bias);
 | 
					        SDVariable z = mmul.add("z", bias);
 | 
				
			||||||
        SDVariable out = sd.math().tanh(z);
 | 
					        SDVariable out = sd.math().tanh(z);
 | 
				
			||||||
@ -888,7 +892,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        val f = m.add(2.0);
 | 
					        val f = m.add(2.0);
 | 
				
			||||||
        val s = in2.add(5.0);
 | 
					        val s = in2.add(5.0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        val arr = sd.execSingle(null, s.getVarName());
 | 
					        val arr = sd.execSingle(null, s.getVarName());
 | 
				
			||||||
        log.info("Result M: {}", m.getArr());
 | 
					        log.info("Result M: {}", m.getArr());
 | 
				
			||||||
        log.info("Result F: {}", f.getArr());
 | 
					        log.info("Result F: {}", f.getArr());
 | 
				
			||||||
@ -939,7 +942,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1);
 | 
					        val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1);
 | 
				
			||||||
        val input1 = sd.var("input", matrix);
 | 
					        val input1 = sd.var("input", matrix);
 | 
				
			||||||
        val input2 = sd.var("input2", vector);
 | 
					        val input2 = sd.var("input2", vector);
 | 
				
			||||||
        val output = sd.mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build());
 | 
					        val output = sd
 | 
				
			||||||
 | 
					                .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build());
 | 
				
			||||||
        output.eval();
 | 
					        output.eval();
 | 
				
			||||||
        assertArrayEquals(new long[]{3, 1}, output.getShape());
 | 
					        assertArrayEquals(new long[]{3, 1}, output.getShape());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -1026,12 +1030,11 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }, inputs);
 | 
					        }, inputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions");
 | 
					        SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions");
 | 
				
			||||||
        Map<String, INDArray> inputsSubset = new HashMap<>();
 | 
					        Map<String, INDArray> inputsSubset = new HashMap<>();
 | 
				
			||||||
        inputsSubset.put("y", inputs.get("y"));
 | 
					        inputsSubset.put("y", inputs.get("y"));
 | 
				
			||||||
        INDArray output = logisticGraph.exec(inputsSubset, Collections.singletonList("rsub")).get("rsub");
 | 
					        INDArray output = logisticGraph.exec(inputsSubset, Collections.singletonList("rsub")).get("rsub");
 | 
				
			||||||
        INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4,1});
 | 
					        INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4, 1});
 | 
				
			||||||
        assertEquals(assertion, output);
 | 
					        assertEquals(assertion, output);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -1076,7 +1079,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        }, inputs);
 | 
					        }, inputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiff logisticPrediction = sameDiffOuter.getFunction("logisticPredictions");
 | 
					        SameDiff logisticPrediction = sameDiffOuter.getFunction("logisticPredictions");
 | 
				
			||||||
        List<String> logisticOpNameAssertions = Arrays.asList("mmul", "sigmoid");
 | 
					        List<String> logisticOpNameAssertions = Arrays.asList("mmul", "sigmoid");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1146,7 +1148,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
                Activation.SOFTPLUS,
 | 
					                Activation.SOFTPLUS,
 | 
				
			||||||
                Activation.SOFTSIGN,
 | 
					                Activation.SOFTSIGN,
 | 
				
			||||||
                Activation.HARDTANH,
 | 
					                Activation.HARDTANH,
 | 
				
			||||||
                Activation.CUBE,            //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
 | 
					                Activation.CUBE,
 | 
				
			||||||
 | 
					                //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
 | 
				
			||||||
                Activation.RELU,            //JVM crash
 | 
					                Activation.RELU,            //JVM crash
 | 
				
			||||||
                Activation.LEAKYRELU        //JVM crash
 | 
					                Activation.LEAKYRELU        //JVM crash
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
@ -1289,8 +1292,9 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        sd.exec(Collections.emptyMap(), sd.outputs());
 | 
					        sd.exec(Collections.emptyMap(), sd.outputs());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for (int i = 0; i < 4; i++)
 | 
					        for (int i = 0; i < 4; i++) {
 | 
				
			||||||
            assertEquals(1, out.getArr().get(all(), NDArrayIndex.point(i), all(), all()).getInt(0));
 | 
					            assertEquals(1, out.getArr().get(all(), NDArrayIndex.point(i), all(), all()).getInt(0));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1327,7 +1331,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        INDArray means = Nd4j.create(new float[]{2, 4}, new long[]{1, 2});
 | 
					        INDArray means = Nd4j.create(new float[]{2, 4}, new long[]{1, 2});
 | 
				
			||||||
        INDArray vars = Nd4j.create(new float[]{6, 8}, new long[]{1, 2});
 | 
					        INDArray vars = Nd4j.create(new float[]{6, 8}, new long[]{1, 2});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable sdCounts = sd.var("counts", counts);
 | 
					        SDVariable sdCounts = sd.var("counts", counts);
 | 
				
			||||||
        SDVariable sdMeans = sd.var("means", means);
 | 
					        SDVariable sdMeans = sd.var("means", means);
 | 
				
			||||||
        SDVariable sdVars = sd.var("vars", vars);
 | 
					        SDVariable sdVars = sd.var("vars", vars);
 | 
				
			||||||
@ -1363,7 +1366,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        int imgH = 28;
 | 
					        int imgH = 28;
 | 
				
			||||||
        int imgW = 28;
 | 
					        int imgW = 28;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        INDArray depthWeightArr = Nd4j.create(kH, kW, nIn, depthWise);
 | 
					        INDArray depthWeightArr = Nd4j.create(kH, kW, nIn, depthWise);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1720,7 +1722,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            SDVariable in1 = sd.var("in1", ia);
 | 
					            SDVariable in1 = sd.var("in1", ia);
 | 
				
			||||||
            SDVariable in2 = sd.var("in2", ib);
 | 
					            SDVariable in2 = sd.var("in2", ib);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
            SDVariable t;
 | 
					            SDVariable t;
 | 
				
			||||||
            INDArray expOut;
 | 
					            INDArray expOut;
 | 
				
			||||||
            switch (i) {
 | 
					            switch (i) {
 | 
				
			||||||
@ -1835,7 +1836,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        val origShape = new long[]{3, 4};
 | 
					        val origShape = new long[]{3, 4};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for (int i = 0; i < 3; i++) {
 | 
					        for (int i = 0; i < 3; i++) {
 | 
				
			||||||
            for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
 | 
					            for (Pair<INDArray, String> p : NDArrayCreationUtil
 | 
				
			||||||
 | 
					                    .getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
 | 
				
			||||||
                INDArray inArr = p.getFirst().muli(100);
 | 
					                INDArray inArr = p.getFirst().muli(100);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                SameDiff sd = SameDiff.create();
 | 
					                SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -1875,7 +1877,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            val shape = origShape.clone();
 | 
					            val shape = origShape.clone();
 | 
				
			||||||
            shape[i] = 1;
 | 
					            shape[i] = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) {
 | 
					            for (Pair<INDArray, String> p : NDArrayCreationUtil
 | 
				
			||||||
 | 
					                    .getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) {
 | 
				
			||||||
                INDArray inArr = p.getFirst().muli(100);
 | 
					                INDArray inArr = p.getFirst().muli(100);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                SameDiff sd = SameDiff.create();
 | 
					                SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -1912,7 +1915,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        val origShape = new long[]{3, 4};
 | 
					        val origShape = new long[]{3, 4};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for (int i = 0; i < 3; i++) {
 | 
					        for (int i = 0; i < 3; i++) {
 | 
				
			||||||
            for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
 | 
					            for (Pair<INDArray, String> p : NDArrayCreationUtil
 | 
				
			||||||
 | 
					                    .getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
 | 
				
			||||||
                INDArray inArr = p.getFirst().muli(100);
 | 
					                INDArray inArr = p.getFirst().muli(100);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                SameDiff sd = SameDiff.create();
 | 
					                SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -1939,7 +1943,8 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
            val shape = origShape.clone();
 | 
					            val shape = origShape.clone();
 | 
				
			||||||
            shape[i] = 1;
 | 
					            shape[i] = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) {
 | 
					            for (Pair<INDArray, String> p : NDArrayCreationUtil
 | 
				
			||||||
 | 
					                    .getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) {
 | 
				
			||||||
                INDArray inArr = p.getFirst().muli(100);
 | 
					                INDArray inArr = p.getFirst().muli(100);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                SameDiff sd = SameDiff.create();
 | 
					                SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -2214,7 +2219,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        SDVariable in = sd.var("in", 1, 2);
 | 
					        SDVariable in = sd.var("in", 1, 2);
 | 
				
			||||||
        sd.associateArrayWithVariable(ia, in);
 | 
					        sd.associateArrayWithVariable(ia, in);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        INDArray expFinite = Nd4j.create(new boolean[]{true, true});
 | 
					        INDArray expFinite = Nd4j.create(new boolean[]{true, true});
 | 
				
			||||||
        SDVariable finite = sd.math().isFinite(in);
 | 
					        SDVariable finite = sd.math().isFinite(in);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -2259,11 +2263,10 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        SDVariable result2 = x.get(SDIndex.point(4), SDIndex.all());
 | 
					        SDVariable result2 = x.get(SDIndex.point(4), SDIndex.all());
 | 
				
			||||||
        assertEquals(expOut2, result2.eval());
 | 
					        assertEquals(expOut2, result2.eval());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray expOut3 = arr.get(NDArrayIndex.interval(3, 8)).reshape(5,10);
 | 
					        INDArray expOut3 = arr.get(NDArrayIndex.interval(3, 8)).reshape(5, 10);
 | 
				
			||||||
        SDVariable result3 = x.get(SDIndex.interval(3, 8));
 | 
					        SDVariable result3 = x.get(SDIndex.interval(3, 8));
 | 
				
			||||||
        assertEquals(expOut3, result3.eval());
 | 
					        assertEquals(expOut3, result3.eval());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        INDArray expOut4 = arr.get(NDArrayIndex.point(5), NDArrayIndex.interval(3, 8)).reshape(5);
 | 
					        INDArray expOut4 = arr.get(NDArrayIndex.point(5), NDArrayIndex.interval(3, 8)).reshape(5);
 | 
				
			||||||
        SDVariable result4 = x.get(SDIndex.point(5), SDIndex.interval(3, 8));
 | 
					        SDVariable result4 = x.get(SDIndex.point(5), SDIndex.interval(3, 8));
 | 
				
			||||||
        assertEquals(expOut4, result4.eval());
 | 
					        assertEquals(expOut4, result4.eval());
 | 
				
			||||||
@ -2295,7 +2298,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        INDArray s3a = s3.eval();
 | 
					        INDArray s3a = s3.eval();
 | 
				
			||||||
        assertEquals(s3a, y3);
 | 
					        assertEquals(s3a, y3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        INDArray y4 = arr.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.interval(3, 5));
 | 
					        INDArray y4 = arr.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.interval(3, 5));
 | 
				
			||||||
        SDVariable s4 = x.get(SDIndex.point(2), SDIndex.all(), SDIndex.interval(3, 5));
 | 
					        SDVariable s4 = x.get(SDIndex.point(2), SDIndex.all(), SDIndex.interval(3, 5));
 | 
				
			||||||
        INDArray s4a = s4.eval();
 | 
					        INDArray s4a = s4.eval();
 | 
				
			||||||
@ -2409,7 +2411,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
                new int[]{3, 2, 4});
 | 
					                new int[]{3, 2, 4});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable x = sd.var(arr);
 | 
					        SDVariable x = sd.var(arr);
 | 
				
			||||||
        SDVariable result = sd.permute(x, 1, 0, 2);
 | 
					        SDVariable result = sd.permute(x, 1, 0, 2);
 | 
				
			||||||
        assertEquals(expOut, result.eval());
 | 
					        assertEquals(expOut, result.eval());
 | 
				
			||||||
@ -2470,7 +2471,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        ExternalErrorsFunction fn = sd.f().externalErrors(out);
 | 
					        ExternalErrorsFunction fn = sd.f().externalErrors(out);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sd.execAndEndResult();
 | 
					        sd.execAndEndResult();
 | 
				
			||||||
        Map<String,INDArray> m = new HashMap<>();
 | 
					        Map<String, INDArray> m = new HashMap<>();
 | 
				
			||||||
        m.put("out-grad", externalGrad);
 | 
					        m.put("out-grad", externalGrad);
 | 
				
			||||||
        sd.execBackwards(m);
 | 
					        sd.execBackwards(m);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -2488,7 +2489,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        assertEquals(externalGrad.mul(0.5), gradVar);
 | 
					        assertEquals(externalGrad.mul(0.5), gradVar);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        //Test model serialization:
 | 
					        //Test model serialization:
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -2620,7 +2620,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        b.setArray(bA);
 | 
					        b.setArray(bA);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray grad = Nd4j.linspace(1, 12, 12, DataType.FLOAT).reshape(3, 4);
 | 
					        INDArray grad = Nd4j.linspace(1, 12, 12, DataType.FLOAT).reshape(3, 4);
 | 
				
			||||||
        Map<String,INDArray> phMap = new HashMap<>();
 | 
					        Map<String, INDArray> phMap = new HashMap<>();
 | 
				
			||||||
        phMap.put(fn.getGradPlaceholderName(), grad);
 | 
					        phMap.put(fn.getGradPlaceholderName(), grad);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        log.info("--------------- sd.execAndEndResult() ---------------");
 | 
					        log.info("--------------- sd.execAndEndResult() ---------------");
 | 
				
			||||||
@ -2723,7 +2723,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
                .build();
 | 
					                .build();
 | 
				
			||||||
        sd.setTrainingConfig(c);
 | 
					        sd.setTrainingConfig(c);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr, null).toMultiDataSet()), 1);
 | 
					        sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr, null).toMultiDataSet()), 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray out = tanh.eval();
 | 
					        INDArray out = tanh.eval();
 | 
				
			||||||
@ -2757,7 +2756,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3);
 | 
					        INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3);
 | 
				
			||||||
        in.setArray(inArr);
 | 
					        in.setArray(inArr);
 | 
				
			||||||
        INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3,4);
 | 
					        INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        TrainingConfig c = TrainingConfig.builder()
 | 
					        TrainingConfig c = TrainingConfig.builder()
 | 
				
			||||||
                .updater(new Adam(0.1))
 | 
					                .updater(new Adam(0.1))
 | 
				
			||||||
@ -2767,7 +2766,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
                .build();
 | 
					                .build();
 | 
				
			||||||
        sd.setTrainingConfig(c);
 | 
					        sd.setTrainingConfig(c);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(new INDArray[]{inArr, inArr2}, null)), 1);
 | 
					        sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(new INDArray[]{inArr, inArr2}, null)), 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray out = tanh.eval();
 | 
					        INDArray out = tanh.eval();
 | 
				
			||||||
@ -2859,7 +2857,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
        final INDArray out = Nd4j.concat(2, output).norm2();
 | 
					        final INDArray out = Nd4j.concat(2, output).norm2();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        final SDVariable sdInput = sd.var("input", input);
 | 
					        final SDVariable sdInput = sd.var("input", input);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -2905,7 +2902,6 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
        final INDArray out = Nd4j.concat(2, output).norm2();
 | 
					        final INDArray out = Nd4j.concat(2, output).norm2();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        final SDVariable sdInput = sd.var("input", input);
 | 
					        final SDVariable sdInput = sd.var("input", input);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -2917,13 +2913,11 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        outputSlices[0] = x_0;
 | 
					        outputSlices[0] = x_0;
 | 
				
			||||||
        outputSlices[0] = sd.expandDims("X_0-e", outputSlices[0], 2);
 | 
					        outputSlices[0] = sd.expandDims("X_0-e", outputSlices[0], 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        final val x_1 = inputSlices[1];
 | 
					        final val x_1 = inputSlices[1];
 | 
				
			||||||
        outputSlices[1] = x_1;
 | 
					        outputSlices[1] = x_1;
 | 
				
			||||||
        outputSlices[1] = outputSlices[1].add(sd.squeeze("X_0-s", outputSlices[0], 2));
 | 
					        outputSlices[1] = outputSlices[1].add(sd.squeeze("X_0-s", outputSlices[0], 2));
 | 
				
			||||||
        outputSlices[1] = sd.expandDims("X_1-e", outputSlices[1], 2);
 | 
					        outputSlices[1] = sd.expandDims("X_1-e", outputSlices[1], 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        SDVariable t = sd.concat(2, outputSlices);
 | 
					        SDVariable t = sd.concat(2, outputSlices);
 | 
				
			||||||
        t.norm2("out");
 | 
					        t.norm2("out");
 | 
				
			||||||
        String err = OpValidation.validate(new TestCase(sd)
 | 
					        String err = OpValidation.validate(new TestCase(sd)
 | 
				
			||||||
@ -3036,7 +3030,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testSameDiffBackprop1(){
 | 
					    public void testSameDiffBackprop1() {
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        final SDVariable a = sd.var("a", Nd4j.rand(4, 4));
 | 
					        final SDVariable a = sd.var("a", Nd4j.rand(4, 4));
 | 
				
			||||||
        final SDVariable b = sd.var("b", Nd4j.rand(4, 4));
 | 
					        final SDVariable b = sd.var("b", Nd4j.rand(4, 4));
 | 
				
			||||||
@ -3050,7 +3044,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testSameDiffNoGradForConstantAndPlaceholder(){
 | 
					    public void testSameDiffNoGradForConstantAndPlaceholder() {
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        final SDVariable a = sd.var("a", Nd4j.rand(4, 4));
 | 
					        final SDVariable a = sd.var("a", Nd4j.rand(4, 4));
 | 
				
			||||||
        final SDVariable b = sd.constant("b", Nd4j.rand(4, 4));
 | 
					        final SDVariable b = sd.constant("b", Nd4j.rand(4, 4));
 | 
				
			||||||
@ -3058,16 +3052,16 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        a.add(b.add(c)).sum().markAsLoss();
 | 
					        a.add(b.add(c)).sum().markAsLoss();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sd.execBackwards(Collections.singletonMap("c", Nd4j.rand(4,4 )));
 | 
					        sd.execBackwards(Collections.singletonMap("c", Nd4j.rand(4, 4)));
 | 
				
			||||||
        assertNotNull(sd.grad("a"));
 | 
					        assertNotNull(sd.grad("a"));
 | 
				
			||||||
        assertNull(sd.grad("b"));
 | 
					        assertNull(sd.grad("b"));
 | 
				
			||||||
        assertNull(sd.grad("c"));
 | 
					        assertNull(sd.grad("c"));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testDuplicateNamePlaceholder(){
 | 
					    public void testDuplicateNamePlaceholder() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for( int i=0; i<2; i++ ) {
 | 
					        for (int i = 0; i < 2; i++) {
 | 
				
			||||||
            SameDiff sd = SameDiff.create();
 | 
					            SameDiff sd = SameDiff.create();
 | 
				
			||||||
            SDVariable x1 = i == 0 ? sd.placeHolder("a", DataType.FLOAT, 5, 3) : sd.var("a", DataType.FLOAT, 5, 3);
 | 
					            SDVariable x1 = i == 0 ? sd.placeHolder("a", DataType.FLOAT, 5, 3) : sd.var("a", DataType.FLOAT, 5, 3);
 | 
				
			||||||
            SDVariable x2 = i == 0 ? sd.placeHolder("b", DataType.FLOAT, 5, 3) : sd.var("b", DataType.FLOAT, 5, 3);
 | 
					            SDVariable x2 = i == 0 ? sd.placeHolder("b", DataType.FLOAT, 5, 3) : sd.var("b", DataType.FLOAT, 5, 3);
 | 
				
			||||||
@ -3119,7 +3113,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testSameDiffGetArrayScalar(){
 | 
					    public void testSameDiffGetArrayScalar() {
 | 
				
			||||||
        final INDArray array = Nd4j.rand(1, 1);
 | 
					        final INDArray array = Nd4j.rand(1, 1);
 | 
				
			||||||
        final SameDiff sd = SameDiff.create();
 | 
					        final SameDiff sd = SameDiff.create();
 | 
				
			||||||
        final SDVariable a = sd.var("a", array.shape());
 | 
					        final SDVariable a = sd.var("a", array.shape());
 | 
				
			||||||
@ -3128,11 +3122,11 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testVariableRenaming(){
 | 
					    public void testVariableRenaming() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        SDVariable v1 = sd.var("x", Nd4j.rand(DataType.FLOAT, 3,4));
 | 
					        SDVariable v1 = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4));
 | 
				
			||||||
        SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4,5));
 | 
					        SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5));
 | 
				
			||||||
        SDVariable v3 = v1.mmul("oldName", v2);
 | 
					        SDVariable v3 = v1.mmul("oldName", v2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray out = sd.execSingle(null, "oldName");
 | 
					        INDArray out = sd.execSingle(null, "oldName");
 | 
				
			||||||
@ -3150,11 +3144,11 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testVariableRenaming2(){
 | 
					    public void testVariableRenaming2() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        SDVariable v1 = sd.placeHolder("x", DataType.FLOAT,3,4);
 | 
					        SDVariable v1 = sd.placeHolder("x", DataType.FLOAT, 3, 4);
 | 
				
			||||||
        SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4,5));
 | 
					        SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5));
 | 
				
			||||||
        SDVariable v3 = v1.mmul("oldName", v2);
 | 
					        SDVariable v3 = v1.mmul("oldName", v2);
 | 
				
			||||||
        SDVariable v4 = v3.std("out", false);
 | 
					        SDVariable v4 = v3.std("out", false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -3172,7 +3166,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testPlaceholderShapeValidation(){
 | 
					    public void testPlaceholderShapeValidation() {
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4);
 | 
					        SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4);
 | 
				
			||||||
        SDVariable ph2 = sd.placeHolder("ph2", DataType.FLOAT, -1, 4);
 | 
					        SDVariable ph2 = sd.placeHolder("ph2", DataType.FLOAT, -1, 4);
 | 
				
			||||||
@ -3183,33 +3177,36 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        INDArray wrongShape = Nd4j.create(DataType.FLOAT, 2, 3);
 | 
					        INDArray wrongShape = Nd4j.create(DataType.FLOAT, 2, 3);
 | 
				
			||||||
        INDArray wrongRank1 = Nd4j.create(DataType.FLOAT, 1);
 | 
					        INDArray wrongRank1 = Nd4j.create(DataType.FLOAT, 1);
 | 
				
			||||||
        INDArray wrongRank2 = Nd4j.create(DataType.FLOAT, 3, 4, 5);
 | 
					        INDArray wrongRank2 = Nd4j.create(DataType.FLOAT, 3, 4, 5);
 | 
				
			||||||
        for(SDVariable v : new SDVariable[]{ph1, ph2, ph3, ph4}){
 | 
					        for (SDVariable v : new SDVariable[]{ph1, ph2, ph3, ph4}) {
 | 
				
			||||||
            v.setArray(correctShape);
 | 
					            v.setArray(correctShape);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if(v != ph4) {
 | 
					            if (v != ph4) {
 | 
				
			||||||
                try {
 | 
					                try {
 | 
				
			||||||
                    v.setArray(wrongShape);
 | 
					                    v.setArray(wrongShape);
 | 
				
			||||||
                    fail("Expected exception");
 | 
					                    fail("Expected exception");
 | 
				
			||||||
                } catch (Exception t) {
 | 
					                } catch (Exception t) {
 | 
				
			||||||
                    String msg = t.getMessage();
 | 
					                    String msg = t.getMessage();
 | 
				
			||||||
                    assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg.contains(Arrays.toString(v.placeholderShape())));
 | 
					                    assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg
 | 
				
			||||||
 | 
					                            .contains(Arrays.toString(v.placeholderShape())));
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            try{
 | 
					            try {
 | 
				
			||||||
                v.setArray(wrongRank1);
 | 
					                v.setArray(wrongRank1);
 | 
				
			||||||
                fail("Expected exception");
 | 
					                fail("Expected exception");
 | 
				
			||||||
            } catch (Exception t){
 | 
					            } catch (Exception t) {
 | 
				
			||||||
                String msg = t.getMessage();
 | 
					                String msg = t.getMessage();
 | 
				
			||||||
                assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg.contains(Arrays.toString(v.placeholderShape())));
 | 
					                assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg
 | 
				
			||||||
 | 
					                        .contains(Arrays.toString(v.placeholderShape())));
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            try{
 | 
					            try {
 | 
				
			||||||
                v.setArray(wrongRank2);
 | 
					                v.setArray(wrongRank2);
 | 
				
			||||||
                fail("Expected exception");
 | 
					                fail("Expected exception");
 | 
				
			||||||
            } catch (Exception t){
 | 
					            } catch (Exception t) {
 | 
				
			||||||
                String msg = t.getMessage();
 | 
					                String msg = t.getMessage();
 | 
				
			||||||
                assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg.contains(Arrays.toString(v.placeholderShape())));
 | 
					                assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg
 | 
				
			||||||
 | 
					                        .contains(Arrays.toString(v.placeholderShape())));
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -3223,9 +3220,9 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
                .markLabelsUnused()
 | 
					                .markLabelsUnused()
 | 
				
			||||||
                .updater(new Adam(1e-3)).build());
 | 
					                .updater(new Adam(1e-3)).build());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try{
 | 
					        try {
 | 
				
			||||||
            sd.fit(mds);
 | 
					            sd.fit(mds);
 | 
				
			||||||
        } catch (Exception t){
 | 
					        } catch (Exception t) {
 | 
				
			||||||
            String msg = t.getMessage();
 | 
					            String msg = t.getMessage();
 | 
				
			||||||
            assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]"));
 | 
					            assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]"));
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@ -3233,7 +3230,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testInferenceWithoutLabel(){
 | 
					    public void testInferenceWithoutLabel() {
 | 
				
			||||||
        //We don't need a value for the label placeholder to calculate most values here
 | 
					        //We don't need a value for the label placeholder to calculate most values here
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -3252,15 +3249,14 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn);
 | 
					        INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,INDArray> m = sd.exec(Collections.singletonMap("in", inputArr), "softmax");
 | 
					        Map<String, INDArray> m = sd.exec(Collections.singletonMap("in", inputArr), "softmax");
 | 
				
			||||||
        assertEquals(1, m.size());
 | 
					        assertEquals(1, m.size());
 | 
				
			||||||
        assertTrue(m.containsKey("softmax"));
 | 
					        assertTrue(m.containsKey("softmax"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray out = m.get("softmax");
 | 
					        INDArray out = m.get("softmax");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3);
 | 
					        INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3);
 | 
				
			||||||
        Map<String,INDArray> allPh = new HashMap<>();
 | 
					        Map<String, INDArray> allPh = new HashMap<>();
 | 
				
			||||||
        allPh.put("in", inputArr);
 | 
					        allPh.put("in", inputArr);
 | 
				
			||||||
        allPh.put("label", labelUnused);
 | 
					        allPh.put("label", labelUnused);
 | 
				
			||||||
        m = sd.exec(allPh, "softmax");
 | 
					        m = sd.exec(allPh, "softmax");
 | 
				
			||||||
@ -3271,7 +3267,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testInferenceWithoutUnnecessaryPlaceholders(){
 | 
					    public void testInferenceWithoutUnnecessaryPlaceholders() {
 | 
				
			||||||
        //We don't need an array for 2 of the placeholders to calculate the
 | 
					        //We don't need an array for 2 of the placeholders to calculate the
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -3293,15 +3289,14 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn);
 | 
					        INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,INDArray> m = sd.exec(Collections.singletonMap("in", inputArr), "softmax");
 | 
					        Map<String, INDArray> m = sd.exec(Collections.singletonMap("in", inputArr), "softmax");
 | 
				
			||||||
        assertEquals(1, m.size());
 | 
					        assertEquals(1, m.size());
 | 
				
			||||||
        assertTrue(m.containsKey("softmax"));
 | 
					        assertTrue(m.containsKey("softmax"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        INDArray out = m.get("softmax");
 | 
					        INDArray out = m.get("softmax");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3);
 | 
					        INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3);
 | 
				
			||||||
        Map<String,INDArray> allPh = new HashMap<>();
 | 
					        Map<String, INDArray> allPh = new HashMap<>();
 | 
				
			||||||
        allPh.put("in", inputArr);
 | 
					        allPh.put("in", inputArr);
 | 
				
			||||||
        allPh.put("label", labelUnused);
 | 
					        allPh.put("label", labelUnused);
 | 
				
			||||||
        allPh.put("in2", Nd4j.scalar(1.0f));
 | 
					        allPh.put("in2", Nd4j.scalar(1.0f));
 | 
				
			||||||
@ -3314,7 +3309,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testConvertDTypes1(){
 | 
					    public void testConvertDTypes1() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4));
 | 
					        SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4));
 | 
				
			||||||
@ -3329,15 +3324,15 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        assertEquals(DataType.FLOAT, tanh.dataType());
 | 
					        assertEquals(DataType.FLOAT, tanh.dataType());
 | 
				
			||||||
        assertEquals(DataType.FLOAT, stdev.dataType());
 | 
					        assertEquals(DataType.FLOAT, stdev.dataType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,INDArray> out = sd.exec(null, "x", "y", "z", "tanh", "stdev");
 | 
					        Map<String, INDArray> out = sd.exec(null, "x", "y", "z", "tanh", "stdev");
 | 
				
			||||||
        for(Map.Entry<String,INDArray> e : out.entrySet()){
 | 
					        for (Map.Entry<String, INDArray> e : out.entrySet()) {
 | 
				
			||||||
            assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType());
 | 
					            assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        assertEquals(DataType.FLOAT, x.getArr().dataType());
 | 
					        assertEquals(DataType.FLOAT, x.getArr().dataType());
 | 
				
			||||||
        assertEquals(DataType.FLOAT, y.getArr().dataType());
 | 
					        assertEquals(DataType.FLOAT, y.getArr().dataType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,DataType> toConvert = new HashMap<>();
 | 
					        Map<String, DataType> toConvert = new HashMap<>();
 | 
				
			||||||
        toConvert.put("x", DataType.DOUBLE);
 | 
					        toConvert.put("x", DataType.DOUBLE);
 | 
				
			||||||
        toConvert.put("y", DataType.DOUBLE);
 | 
					        toConvert.put("y", DataType.DOUBLE);
 | 
				
			||||||
        sd.convertDataTypes(toConvert);
 | 
					        sd.convertDataTypes(toConvert);
 | 
				
			||||||
@ -3349,7 +3344,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        assertEquals(DataType.DOUBLE, stdev.dataType());
 | 
					        assertEquals(DataType.DOUBLE, stdev.dataType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        out = sd.exec(null, "x", "y", "z", "tanh", "stdev");
 | 
					        out = sd.exec(null, "x", "y", "z", "tanh", "stdev");
 | 
				
			||||||
        for(Map.Entry<String,INDArray> e : out.entrySet()){
 | 
					        for (Map.Entry<String, INDArray> e : out.entrySet()) {
 | 
				
			||||||
            assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
 | 
					            assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -3358,7 +3353,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testConvertDTypes2(){
 | 
					    public void testConvertDTypes2() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        SameDiff sd = SameDiff.create();
 | 
					        SameDiff sd = SameDiff.create();
 | 
				
			||||||
        SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3, 4);
 | 
					        SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3, 4);
 | 
				
			||||||
@ -3375,11 +3370,11 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        assertEquals(DataType.DOUBLE, add.dataType());
 | 
					        assertEquals(DataType.DOUBLE, add.dataType());
 | 
				
			||||||
        assertEquals(DataType.DOUBLE, relu.dataType());
 | 
					        assertEquals(DataType.DOUBLE, relu.dataType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,INDArray> ph = Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4));
 | 
					        Map<String, INDArray> ph = Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,INDArray> out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
 | 
					        Map<String, INDArray> out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
 | 
				
			||||||
        for(Map.Entry<String,INDArray> e : out.entrySet()){
 | 
					        for (Map.Entry<String, INDArray> e : out.entrySet()) {
 | 
				
			||||||
            if(e.getKey().equals("x") || e.getKey().equals("y")){
 | 
					            if (e.getKey().equals("x") || e.getKey().equals("y")) {
 | 
				
			||||||
                assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType());
 | 
					                assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType());
 | 
				
			||||||
            } else {
 | 
					            } else {
 | 
				
			||||||
                assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
 | 
					                assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
 | 
				
			||||||
@ -3388,7 +3383,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        assertEquals(DataType.FLOAT, y.getArr().dataType());
 | 
					        assertEquals(DataType.FLOAT, y.getArr().dataType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Map<String,DataType> toConvert = new HashMap<>();
 | 
					        Map<String, DataType> toConvert = new HashMap<>();
 | 
				
			||||||
        toConvert.put("x", DataType.DOUBLE);
 | 
					        toConvert.put("x", DataType.DOUBLE);
 | 
				
			||||||
        toConvert.put("y", DataType.DOUBLE);
 | 
					        toConvert.put("y", DataType.DOUBLE);
 | 
				
			||||||
        sd.convertDataTypes(toConvert);
 | 
					        sd.convertDataTypes(toConvert);
 | 
				
			||||||
@ -3401,7 +3396,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        assertEquals(DataType.DOUBLE, relu.dataType());
 | 
					        assertEquals(DataType.DOUBLE, relu.dataType());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
 | 
					        out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
 | 
				
			||||||
        for(Map.Entry<String,INDArray> e : out.entrySet()){
 | 
					        for (Map.Entry<String, INDArray> e : out.entrySet()) {
 | 
				
			||||||
            assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
 | 
					            assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -3410,11 +3405,11 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @Test
 | 
					    @Test
 | 
				
			||||||
    public void testGradFnRequiredVars(){
 | 
					    public void testGradFnRequiredVars() {
 | 
				
			||||||
        //User can explicitly request that gradients for specific vars are available when differentiating (creating grad function),
 | 
					        //User can explicitly request that gradients for specific vars are available when differentiating (creating grad function),
 | 
				
			||||||
        // even if they normally wouldn't be needed or calculated
 | 
					        // even if they normally wouldn't be needed or calculated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for(boolean reqPhVar : new boolean[]{false, true}){
 | 
					        for (boolean reqPhVar : new boolean[]{false, true}) {
 | 
				
			||||||
//        for(boolean reqPhVar : new boolean[]{true}){
 | 
					//        for(boolean reqPhVar : new boolean[]{true}){
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            SameDiff sd = SameDiff.create();
 | 
					            SameDiff sd = SameDiff.create();
 | 
				
			||||||
@ -3429,7 +3424,7 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            INDArray in = Nd4j.rand(DataType.FLOAT, 1, 5);
 | 
					            INDArray in = Nd4j.rand(DataType.FLOAT, 1, 5);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if(reqPhVar){
 | 
					            if (reqPhVar) {
 | 
				
			||||||
                sd.createGradFunction("in");
 | 
					                sd.createGradFunction("in");
 | 
				
			||||||
                assertNotNull(ph.gradient());
 | 
					                assertNotNull(ph.gradient());
 | 
				
			||||||
                assertNotNull(w.gradient());
 | 
					                assertNotNull(w.gradient());
 | 
				
			||||||
@ -3447,6 +3442,129 @@ public class SameDiffTests extends BaseNd4jTest {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void testIf() throws IOException {
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					        SDVariable a = SD.placeHolder("a", DataType.DOUBLE);
 | 
				
			||||||
 | 
					        SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
 | 
				
			||||||
 | 
					        SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable output = SD.ifCond("out", null, (sd) -> a.lt(b), (sd) -> c, (sd) -> c.add(5));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Map<String, INDArray> firstBranch = Maps.newHashMap();
 | 
				
			||||||
 | 
					        firstBranch.put("a", Nd4j.createFromArray(3.0));
 | 
				
			||||||
 | 
					        assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Map<String, INDArray> secondBranch = Maps.newHashMap();
 | 
				
			||||||
 | 
					        secondBranch.put("a", Nd4j.createFromArray(7.0));
 | 
				
			||||||
 | 
					        assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        //TODO complains that it can't deserialize a meta type, but there are no meta type ops here
 | 
				
			||||||
 | 
					        // looks like a difference between Op.Type and OpType.  Switch is saved as a OpType.LOGIC
 | 
				
			||||||
 | 
					        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out"));
 | 
				
			||||||
 | 
					        assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out"));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void testNestedIf() throws IOException {
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					        SDVariable a = SD.var("a", Nd4j.createFromArray(2.0));
 | 
				
			||||||
 | 
					        SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
 | 
				
			||||||
 | 
					        SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
 | 
				
			||||||
 | 
					        SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable output = SD.ifCond("out", null,
 | 
				
			||||||
 | 
					                (sd) -> a.lt(b),
 | 
				
			||||||
 | 
					                (sd) -> sd.ifCond(
 | 
				
			||||||
 | 
					                        (sd2) -> d.lte(0),
 | 
				
			||||||
 | 
					                        (sd2) -> c.add(1),
 | 
				
			||||||
 | 
					                        (sd2) -> d),
 | 
				
			||||||
 | 
					                (sd) -> c.add(5));
 | 
				
			||||||
 | 
					        INDArray out = output.eval();
 | 
				
			||||||
 | 
					        assertEquals(Nd4j.createFromArray(10.0), out);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertEquals(Nd4j.createFromArray(10.0), SD.exec(null, "out").get("out"));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void testWhile() throws IOException {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					        SDVariable countIn = SD.constant(5);
 | 
				
			||||||
 | 
					        SDVariable sumIn = SD.constant(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn},
 | 
				
			||||||
 | 
					                (sd, vars) -> vars[0].gt(0),
 | 
				
			||||||
 | 
					                (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        INDArray out = sum[1].eval();
 | 
				
			||||||
 | 
					        assertEquals(15, out.getInt(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        String outName = sum[1].getVarName();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertEquals(15, SD.exec(null, outName).get(outName).getInt(0));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    @Ignore
 | 
				
			||||||
 | 
					    public void testNestedWhile() throws IOException {
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					        SDVariable countIn = SD.constant(5);
 | 
				
			||||||
 | 
					        SDVariable sumIn = SD.constant(0);
 | 
				
			||||||
 | 
					        SDVariable sum2 = SD.constant(0);
 | 
				
			||||||
 | 
					        //TODO creating constant instead of using sum2 causes errors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
 | 
				
			||||||
 | 
					                (sd, vars) -> vars[0].gt(0),
 | 
				
			||||||
 | 
					                (sd, vars) -> new SDVariable[]{vars[0].sub(1),
 | 
				
			||||||
 | 
					                        vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2},
 | 
				
			||||||
 | 
					                                (sd2, vars2) -> vars2[0].gt(0),
 | 
				
			||||||
 | 
					                                (sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        INDArray out = sum[1].eval();
 | 
				
			||||||
 | 
					        assertEquals(35, out.getInt(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        String outName = sum[1].getVarName();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertEquals(35, SD.exec(null, outName).get(outName).getInt(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Test
 | 
				
			||||||
 | 
					    public void testNestedWhileIf() throws IOException {
 | 
				
			||||||
 | 
					        SameDiff SD = SameDiff.create();
 | 
				
			||||||
 | 
					        SDVariable countIn = SD.constant(5);
 | 
				
			||||||
 | 
					        SDVariable sumIn = SD.constant(0);
 | 
				
			||||||
 | 
					        SDVariable hundred = SD.constant(100);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
 | 
				
			||||||
 | 
					                (sd, vars) -> vars[0].gte(0),
 | 
				
			||||||
 | 
					                (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(
 | 
				
			||||||
 | 
					                        sd.ifCond((sd2) -> vars[0].eq(0),
 | 
				
			||||||
 | 
					                                (sd2) -> vars[0].add(100), //TODO replace with hundred and things break
 | 
				
			||||||
 | 
					                                (sd2) -> vars[0])
 | 
				
			||||||
 | 
					                )});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        INDArray out = sum[1].eval();
 | 
				
			||||||
 | 
					        assertEquals(115, out.getInt(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        String outName = sum[1].getVarName();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assertEquals(115, SD.exec(null, outName).get(outName).getInt(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user