SameDiff If, While, and Misc changes (#52)

* softmax and logSoftmax w/ dimension

Signed-off-by: Ryan Nett <rnett@skymind.io>

* start of while

Signed-off-by: Ryan Nett <rnett@skymind.io>

* if, start of javadocs

Signed-off-by: Ryan Nett <rnett@skymind.io>

* while forward pass working, backprop WIP

Signed-off-by: Ryan Nett <rnett@skymind.io>

* no backprop

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Tensorflow style if/while (& tests), name scope fixes (and test), argument interceptor (for if/while), use '_' in op names instead of ':'

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc

Signed-off-by: Ryan Nett <rnett@skymind.io>

* many fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* many fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Some fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* clean up when the if condition doesn't return a boolean

Signed-off-by: Ryan Nett <rnett@skymind.io>

* serialization fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use constants instead of magic numbers

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-07-11 21:08:15 -07:00 committed by AlexDBlack
parent 2d991f5445
commit daf3950d8d
23 changed files with 1662 additions and 379 deletions


@@ -451,6 +451,17 @@ public abstract class DifferentialFunction {
}
}
public void replaceArg(int i, SDVariable newArg){
if(sameDiff != null){
sameDiff.replaceArgFor(i, newArg, this);
if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){
sameDiff.removePropertyToResolve(this, args()[i].getVarName());
} else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){
sameDiff.addPropertyToResolve(this, newArg.getVarName());
}
}
}
/**
* Return the output variables for this differential function.
@@ -652,9 +663,9 @@
scope = "";
else
scope = scope + "/";
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex).replace(":", "_");
while(sameDiff.functionExists(varName)) {
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
varName = scope + sameDiff.generateNewVarName(opName(), argIndex).replace(":", "_");
argIndex++;
}


@@ -16,6 +16,11 @@
package org.nd4j.autodiff.functions;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Data;
import lombok.NonNull;
import lombok.val;
@@ -30,36 +35,183 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
import org.nd4j.linalg.api.ops.impl.indexaccum.*;
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.loss.*;
import org.nd4j.linalg.api.ops.impl.loss.bp.*;
import org.nd4j.linalg.api.ops.impl.reduce.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss;
import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss;
import org.nd4j.linalg.api.ops.impl.loss.HingeLoss;
import org.nd4j.linalg.api.ops.impl.loss.HuberLoss;
import org.nd4j.linalg.api.ops.impl.loss.L2Loss;
import org.nd4j.linalg.api.ops.impl.loss.LogLoss;
import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits;
import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.bp.*;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul;
import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.reduce.same.AMax;
import org.nd4j.linalg.api.ops.impl.reduce.same.AMin;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.api.ops.impl.reduce.same.Max;
import org.nd4j.linalg.api.ops.impl.reduce.same.Min;
import org.nd4j.linalg.api.ops.impl.reduce.same.*;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.api.ops.impl.reduce.same.Prod;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.LogX;
import org.nd4j.linalg.api.ops.impl.scalar.Pow;
import org.nd4j.linalg.api.ops.impl.scalar.*;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.*;
import org.nd4j.linalg.api.ops.impl.scatter.*;
import org.nd4j.linalg.api.ops.impl.shape.*;
import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Broadcast;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.shape.Diag;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
import org.nd4j.linalg.api.ops.impl.shape.ExpandDims;
import org.nd4j.linalg.api.ops.impl.shape.Gather;
import org.nd4j.linalg.api.ops.impl.shape.GatherNd;
import org.nd4j.linalg.api.ops.impl.shape.MergeAvg;
import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
import org.nd4j.linalg.api.ops.impl.shape.MeshGrid;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.shape.ParallelStack;
import org.nd4j.linalg.api.ops.impl.shape.Permute;
import org.nd4j.linalg.api.ops.impl.shape.Rank;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.ops.impl.shape.Repeat;
import org.nd4j.linalg.api.ops.impl.shape.Reshape;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.shape.Size;
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
import org.nd4j.linalg.api.ops.impl.shape.Slice;
import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.StridedSlice;
import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
import org.nd4j.linalg.api.ops.impl.shape.Unstack;
import org.nd4j.linalg.api.ops.impl.shape.ZerosLike;
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
@@ -77,37 +229,165 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.api.ops.impl.transforms.same.*;
import org.nd4j.linalg.api.ops.impl.transforms.segment.*;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.ops.impl.transforms.same.Abs;
import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.same.Floor;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal;
import org.nd4j.linalg.api.ops.impl.transforms.same.Round;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.same.Square;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p;
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
import org.nd4j.linalg.api.ops.random.custom.RandomNormal;
import org.nd4j.linalg.api.ops.random.impl.*;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.Range;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.util.ArrayUtil;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
*
*/
@@ -1611,11 +1891,24 @@ public class DifferentialFunctionFactory {
}
public SDVariable logSoftmax(SDVariable i_v, int dimension) {
validateDifferentialFunctionsameDiff(i_v);
return new LogSoftMax(sameDiff(), i_v, dimension).outputVariable();
}
public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) {
validateDifferentialFunctionsameDiff(arg);
return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable();
}
public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt, int dimension) {
validateDifferentialFunctionsameDiff(arg);
return new LogSoftMaxDerivative(sameDiff(), arg, wrt, dimension).outputVariable();
}
public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... dimension) {
return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable();
}
@@ -2296,6 +2589,22 @@
return tile(func, ArrayUtil.toInts(input.getShape()));
}
public SDVariable enter(SDVariable x, String frameName){
return new Enter(sameDiff, frameName, x).outputVariable();
}
public SDVariable enter(SDVariable x, String frameName, boolean isConstant){
return new Enter(sameDiff, frameName, x, isConstant).outputVariable();
}
public SDVariable exit(SDVariable x){
return new Exit(sameDiff, x).outputVariable();
}
public SDVariable nextIteration(SDVariable x){
return new NextIteration(sameDiff, x).outputVariable();
}
public String toString() {
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
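
The enter, exit, and nextIteration factory methods above expose the TensorFlow-style control flow primitives to the rest of the codebase. As a rough sketch of how they compose into a single loop step (a hedged illustration only: the loopStep method, its arguments, and the frame name are invented here; the real assembly is done by SDBaseOps.whileLoop later in this commit):

// Hedged sketch of one TensorFlow-style loop step built from the new primitives.
// 'factory', 'sd', and all variables are hypothetical stand-ins.
SDVariable loopStep(DifferentialFunctionFactory factory, SameDiff sd,
        SDVariable x, SDVariable cond, SDVariable bodyOutput) {
    SDVariable entered = factory.enter(x, "frame");           // move x into the loop frame
    Merge merge = new Merge(sd, entered, entered);            // 2nd arg is a placeholder for now
    SDVariable merged = merge.outputVariable();
    SDVariable[] branches = factory.switchOp(merged, cond);   // route on the loop condition
    SDVariable exited = factory.exit(branches[0]);            // taken once cond is false
    merge.replaceArg(1, factory.nextIteration(bodyOutput));   // feed the body result back in
    return exited;
}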


@@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
/**
* Internal interface used to apply a transform to any arguments used within a certain block
*
* Intended for internal use only.
*
* Managed with {@link SameDiff#addArgumentInterceptor(ArgumentInterceptor)}, {@link SameDiff#removeArgumentInterceptor()},
* {@link SameDiff#pauseArgumentInterceptor()}, and {@link SameDiff#unpauseArgumentInterceptor()}
*
*/
public interface ArgumentInterceptor {
SDVariable intercept(SDVariable argument);
}
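
For illustration, a minimal interceptor might record every argument it sees and pass it through unchanged (a hedged sketch: the RecordingInterceptor class is invented, but intercept, addArgumentInterceptor, and removeArgumentInterceptor are exactly the methods added in this commit):

import java.util.ArrayList;
import java.util.List;

// Hypothetical interceptor: records argument names, applies no transform.
public class RecordingInterceptor implements ArgumentInterceptor {
    private final List<String> seen = new ArrayList<>();

    @Override
    public SDVariable intercept(SDVariable argument) {
        seen.add(argument.getVarName());  // record the name
        return argument;                  // return the argument unchanged
    }

    public List<String> getSeen() {
        return seen;
    }
}

It would be pushed around a block of op definitions with sd.addArgumentInterceptor(new RecordingInterceptor()) and popped afterwards with sd.removeArgumentInterceptor(), mirroring how whileLoop wraps its body below.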


@@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff;
import java.util.Objects;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
@@ -91,7 +92,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
String nameScope = sameDiff.currentNameScope();
if(nameScope != null){
if(nameScope != null && !varName.startsWith(nameScope + "/")){
varName = nameScope + "/" + varName;
}
@@ -1785,26 +1786,6 @@
(variableType == VariableType.PLACEHOLDER && shape != null ? ",shape=" + Arrays.toString(shape): "") + ")";
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (!super.equals(o)) return false;
SDVariable that = (SDVariable) o;
if (varName != null ? !varName.equals(that.varName) : that.varName != null) return false;
return weightInitScheme != null ? weightInitScheme.equals(that.weightInitScheme) : that.weightInitScheme == null;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (varName != null ? varName.hashCode() : 0);
result = 31 * result + (weightInitScheme != null ? weightInitScheme.hashCode() : 0);
return result;
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
@@ -1966,4 +1947,35 @@
return x;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SDVariable)) {
return false;
}
SDVariable that = (SDVariable) o;
if (!Objects.equals(varName, that.varName)) {
return false;
}
if (variableType != that.variableType) {
return false;
}
if(sameDiff != that.sameDiff){
return false;
}
return dataType == that.dataType;
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + (varName != null ? varName.hashCode() : 0);
result = 31 * result + (variableType != null ? variableType.hashCode() : 0);
result = 31 * result + (dataType != null ? dataType.hashCode() : 0);
return result;
}
}


@@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
@@ -246,6 +247,14 @@ public class SameDiff extends SDBaseOps {
private boolean resolvedVariables = false;
@Getter
private Stack<ArgumentInterceptor> argumentInterceptors = new Stack<>();
@Getter
private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<>();
private Set<String> blockNames = new HashSet<>();
@Getter
@Setter
boolean logExecution = true;
@@ -472,7 +481,10 @@
if(scope == null){
return name;
}
if(!name.startsWith(scope + "/"))
return scope + "/" + name;
else
return name;
}
//Intentionally package private
@@ -533,6 +545,24 @@
}
public List<SameDiffOp> getOpsInScope(NameScope scope){
ArrayList<SameDiffOp> ops = new ArrayList<>();
for(SameDiffOp v : this.ops.values()){
if(v.getName().startsWith(scope.getName()))
ops.add(v);
}
return ops;
}
public List<SDVariable> getVariablesInScope(NameScope scope){
ArrayList<SDVariable> vars = new ArrayList<>();
for(SDVariable v : variables()){
if(v.getVarName().startsWith(scope.getName()))
vars.add(v);
}
return vars;
}
/**
* @param sameDiff * @param sameDiff
* @return * @return
@@ -1109,6 +1139,19 @@
}
}
/**
* Remove a property to resolve added with {@link #addPropertyToResolve(DifferentialFunction, String)}
*
* @param forFunction the function to remove the property to resolve from
* @param arrayName the array name
*/
public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName) {
if (propertiesToResolve.containsKey(forFunction.getOwnName())) {
List<String> newVal = propertiesToResolve.get(forFunction.getOwnName());
newVal.remove(arrayName);
}
}
/**
* Return the properties to resolve for the given function.
* This is typically used right before execution in model import in
@@ -1272,6 +1315,92 @@
}
}
/**
* Add a new argument interceptor to the interceptor stack
*
* For internal use only.
*
* When an op is added with arguments, the most recent argument interceptor is called on it.
* If ops are added in that interceptor, the next most recent will be called on their args, and so on.
*
* @param interceptor the argument interceptor to add
*/
public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){
argumentInterceptors.push(interceptor);
}
private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor){
return pausedArgumentInterceptors.contains(interceptor);
}
private ArgumentInterceptor getArgumentInterceptorToUse(){
if(argumentInterceptors.isEmpty())
return null;
ArgumentInterceptor use = argumentInterceptors.peek();
int i = 1;
while(isArgumentInterceptorPaused(use)){
if(argumentInterceptors.size() - i < 0)
return null;
use = argumentInterceptors.elementAt(argumentInterceptors.size() - i);
i++;
}
return use;
}
/**
* Remove the top (most recently added) argument interceptor
*
* For internal use only.
*/
public void removeArgumentInterceptor(){
if(!argumentInterceptors.isEmpty())
argumentInterceptors.pop();
}
/**
* Pause the top (most recently added) argument interceptor
*
* For internal use only.
*/
public void pauseArgumentInterceptor(){
pausedArgumentInterceptors.add(argumentInterceptors.peek());
}
/**
* Pause the given argument interceptor
*
* For internal use only.
*
* @param interceptor the argument interceptor to pause
*/
public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){
pausedArgumentInterceptors.add(interceptor);
}
/**
* Unpause the top (most recently added) argument interceptor
*
* For internal use only.
*/
public void unpauseArgumentInterceptor(){
pausedArgumentInterceptors.remove(argumentInterceptors.peek());
}
/**
* Unpause the given argument interceptor
*
* For internal use only.
*
* @param interceptor the argument interceptor to unpause
*/
public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){
pausedArgumentInterceptors.remove(interceptor);
}
/**
* Adds incoming arguments for the specified differential function to the graph
*
@@ -1279,6 +1408,17 @@
* @param function Function
*/
public void addArgsFor(String[] variables, DifferentialFunction function) {
ArgumentInterceptor interceptor = getArgumentInterceptorToUse();
if(interceptor != null) {
pauseArgumentInterceptor(interceptor);
for (int i = 0; i < variables.length; i++) {
variables[i] = interceptor.intercept(getVariable(variables[i])).getVarName();
}
unpauseArgumentInterceptor(interceptor);
}
if (function.getOwnName() == null)
throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
@@ -1309,7 +1449,6 @@
}
}
/**
* Adds incoming arguments for the specified differential function to the graph
*
@@ -1317,6 +1456,7 @@
* @param function Function
*/
public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
String[] varNames = new String[variables.length];
for (int i = 0; i < varNames.length; i++) {
if (variables[i] == null)
@@ -1326,6 +1466,58 @@
addArgsFor(varNames, function);
}
/**
* Replaces the argument at i with newArg for function
* Does not apply (or remove) any argument interceptors
*/
public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function){
Preconditions.checkArgument(i < function.args().length, "Index out of range: function " +
function.getOwnName() + " only has " + function.args().length + " args but you are trying" +
"to replace the argument at " + i);
String oldName = function.arg(i).getVarName();
String newName = newArg.getVarName();
if(function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()){
boolean otherPlaceholders = false;
for(int j = 0 ; j < function.argNames().length ; j++){
if(j == i)
continue;
if(function.arg(j).isPlaceHolder())
otherPlaceholders = true;
}
if(!otherPlaceholders)
placeHolderFunctions.remove(function.getOwnName());
} else if(!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()){
if(!placeHolderFunctions.contains(function.getOwnName()))
placeHolderFunctions.add(function.getOwnName());
}
List<String> oldArgs = ops.get(function.getOwnName()).getInputsToOp();
oldArgs = new ArrayList<>(oldArgs);
oldArgs.set(i, newName);
ops.get(function.getOwnName()).setInputsToOp(oldArgs);
List<String> funcs = this.variables.get(newName).getInputsForOp();
if (funcs == null) {
funcs = new ArrayList<>();
this.variables.get(newName).setInputsForOp(funcs);
}
if(!funcs.contains(function.getOwnName())) //Avoid duplicates for function names.
funcs.add(function.getOwnName());
List<String> oldFuncs = this.variables.get(oldName).getInputsForOp();
if(oldFuncs != null) {
if(!ArrayUtils.contains(function.argNames(), oldName))
oldFuncs.remove(function.getOwnName());
}
}
/**
* Get the differential function (if any) that this variable is the output for
*
@@ -1519,6 +1711,7 @@
//A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed
// by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here
// This applies to SameDiff while loops as well
if(o.getOp() instanceof Switch){
continue;
}
@@ -2239,6 +2432,7 @@
if (name == null || name.length() < 1)
name = getNewVarName();
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
name = v.getVarName();
variables.put(name, Variable.builder().name(name).variable(v).build());
constantArrays.put(name, new DeviceLocalNDArray(constant));
return v;
@@ -2305,6 +2499,7 @@
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
String withScope = nameWithScope(name);
if (variables.containsKey(withScope)) {
if(nameScopes.isEmpty()){
throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
@@ -3414,12 +3609,9 @@
/**
* Creates a while statement
*
* @param sameDiffConditional
* @param loopBody
* @return
* @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
*/
@Deprecated
public While whileStatement(SameDiffConditional sameDiffConditional,
SameDiffFunctionDefinition conditionBody,
SameDiffFunctionDefinition loopBody
@@ -3435,11 +3627,9 @@
}
/**
* @param conditional
* @param trueBody
* @param falseBody
* @return
* @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
*/
@Deprecated
public If ifStatement(SameDiffConditional conditional,
SameDiffFunctionDefinition conditionBody,
SameDiffFunctionDefinition trueBody,
@@ -5466,5 +5656,27 @@
return out;
}
/**
* For internal use only.
* Creates a new distinct block name from baseName.
* Block names are used by If and While
*/
public String newBlockName(String baseName){
if(baseName == null)
return null;
if(!blockNames.contains(baseName)){
blockNames.add(baseName);
return baseName;
} else {
int i = 1;
while(blockNames.contains(baseName + "_" + i)){
i++;
}
blockNames.add(baseName + "_" + i);
return baseName + "_" + i;
}
}
}
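
To illustrate the deduplication behavior of newBlockName (a hedged sketch; the variable names are invented, but the behavior follows directly from the code above):

SameDiff sd = SameDiff.create();
String a = sd.newBlockName("while");  // "while": first use of the base name
String b = sd.newBlockName("while");  // "while_1": suffix appended to keep names distinct
String c = sd.newBlockName("while");  // "while_2"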


@@ -0,0 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
/**
* A basic SameDiff lambda, used in while loop creation (the body).
*/
public interface SameDiffLambda {
SDVariable[] define(SameDiff sameDiff, SDVariable[] inputs);
}


@@ -0,0 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
/**
* A SameDiff lambda with only one output and no arguments. Used in if condition creation (the condition and bodies).
*/
public interface SameDiffNoArgSingleLambda {
SDVariable define(SameDiff sameDiff);
}


@@ -0,0 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff;
/**
* A SameDiff lambda with only one output, used in while loop creation (the condition).
*/
public interface SameDiffSingleLambda {
SDVariable define(SameDiff sameDiff, SDVariable[] inputs);
}


@@ -16,12 +16,25 @@
package org.nd4j.autodiff.samediff.ops;
import com.google.common.collect.Sets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.samediff.ArgumentInterceptor;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.autodiff.samediff.SameDiffLambda;
import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda;
import org.nd4j.autodiff.samediff.SameDiffSingleLambda;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.indexing.conditions.Condition;
@@ -3142,4 +3155,304 @@ public abstract class SDBaseOps {
SDVariable ret = f().zerosLike(name, input);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #any(String, SDVariable, int...)}
*/
public SDVariable any(SDVariable x, int... dimensions){
return any(null, x, dimensions);
}
//TODO check any w/ no dimensions
/**
* Boolean or array reduction operation, optionally along specified dimensions
*
* @param name Name of the output variable
* @param x Input variable
* @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
* @return Output variable: reduced array of rank (input rank - num dimensions)
*/
public SDVariable any(String name, SDVariable x, int... dimensions){
validateBool("any", x);
SDVariable ret = f().any(x, dimensions);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #all(String, SDVariable, int...)}
*/
public SDVariable all(SDVariable x, int... dimensions){
return all(null, x, dimensions);
}
/**
* Boolean and array reduction operation, optionally along specified dimensions
*
* @param name Name of the output variable
* @param x Input variable
* @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed
* @return Output variable: reduced array of rank (input rank - num dimensions)
*/
public SDVariable all(String name, SDVariable x, int... dimensions){
validateBool("all", x);
SDVariable ret = f().all(x, dimensions);
return updateVariableNameAndReference(ret, name);
}
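
As a usage sketch for the new reductions (hedged: the names and shapes are invented, the placeHolder signature is assumed from the SameDiff API of this era, and the dimension-less call is still marked TODO above):

SameDiff sd = SameDiff.create();
// A 2x3 boolean placeholder to reduce over
SDVariable in = sd.placeHolder("in", DataType.BOOL, 2, 3);
SDVariable anyRow = sd.any("anyRow", in, 1);  // true where any element in a row is true
SDVariable allRow = sd.all("allRow", in, 1);  // true where every element in a row is true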
/**
* See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
*/
public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars,
@NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
return whileLoop(null, null, loopVars, cond, body);
}
/**
* See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
*/
public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars,
@NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
return whileLoop(null, loopName, loopVars, cond, body);
}
/**
* Constructs a While loop using the TensorFlow-style control flow operations (Switch, Merge, Enter, Exit, and NextIteration)
*
* Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false
*
* Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations.
*
* See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
*
* @param outputNames Names to give the output variables. If null, doesn't rename
* @param loopName The name of the loop block and frame (must be unique). If null, uses "while"
* @param loopVars Loop variables' inputs
* @param cond A lambda evaluating to the loop condition
* @param body A lambda doing the loop operation and returning the new loop variable values
* @return The values of the loop variables once condition is false
*/
public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars,
@NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){
final String frameName = sd().newBlockName(loopName == null ? "while" : loopName);
NameScope loopScope = sd().withNameScope(frameName);
//SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0);
SDVariable[] entered = new SDVariable[loopVars.length];
for(int i = 0 ; i < loopVars.length ; i++){
entered[i] = f().enter(loopVars[i], frameName);
}
//counter = SD.f().enter(counter, frameName);
SDVariable[] merged = new SDVariable[loopVars.length];
Merge[] mergeOps = new Merge[loopVars.length];
for(int i = 0 ; i < loopVars.length ; i++){
// the second arg will later be replaced with the output of NextIteration
// but that isn't available yet (and can't be, as it depends on this)
mergeOps[i] = new Merge(sd(), entered[i], entered[i]);
merged[i] = mergeOps[i].outputVariable();
}
//Merge counterMerge = new Merge(SD, counter, counter);
//counter = counterMerge.outputVariable();
NameScope condScope = sd().withNameScope("cond");
SDVariable cond_result = cond.define(sd(), merged);
condScope.close();
if (cond_result.dataType() != DataType.BOOL)
throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean.");
final Set<String> alreadyEntered = Sets.newHashSet();
SDVariable[] trueSwitches = new SDVariable[loopVars.length];
SDVariable[] exits = new SDVariable[loopVars.length];
for(int i = 0 ; i < loopVars.length ; i++){
SDVariable[] s = f().switchOp(merged[i], cond_result);
trueSwitches[i] = s[1];
alreadyEntered.add(s[1].getVarName());
exits[i] = f().exit(s[0]);
}
//SDVariable[] cs = SD.f().switchOp(counter, cond_result);
//SDVariable counterExit = SD.f().exit(cs[0]);
//counter = cs[1];
final Set<String> declared = Sets.newHashSet(sd().variableMap().keySet());
final Map<String, SDVariable> done = new HashMap<>();
sd().addArgumentInterceptor(new ArgumentInterceptor() {
@Override
public SDVariable intercept(SDVariable argument) {
if(!declared.contains(argument.getVarName()))
return argument;
if(alreadyEntered.contains(argument.getVarName()))
return argument;
if(done.containsKey(argument.getVarName()))
return done.get(argument.getVarName());
SDVariable e = f().enter(argument, frameName, true);
done.put(argument.getVarName(), e);
return e;
}
});
NameScope bodyScope = sd().withNameScope("body");
SDVariable[] outs = body.define(sd(), trueSwitches);
bodyScope.close();
sd().removeArgumentInterceptor();
//counter.add(1);
for(int i = 0 ; i < loopVars.length ; i++){
SDVariable n = f().nextIteration(outs[i]);
mergeOps[i].replaceArg(1,n);
}
//counterMerge.replaceArg(1, counter);
loopScope.close();
return updateVariableNamesAndReferences(exits, outputNames);
}
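/*
 * Usage sketch (illustrative only; "sd" is a SameDiff instance, and the lambda shapes assume
 * cond is define(SameDiff, SDVariable[]) -> SDVariable and body is
 * define(SameDiff, SDVariable[]) -> SDVariable[]): count up from 0 until the loop variable reaches 10.
 *
 *   SDVariable start = sd.constant(0);
 *   SDVariable[] out = sd.whileLoop(
 *           new SDVariable[]{start},                            // initial loop variables
 *           (s, vars) -> vars[0].lt(10.0),                      // cond: keep looping while x < 10
 *           (s, vars) -> new SDVariable[]{vars[0].add(1.0)});   // body: x = x + 1
 *
 * out[0] is the Exit output of the first loop variable, i.e. its value once cond is false.
 */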
/**
* See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
*/
public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond,
@NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
return ifCond(null, null, cond, trueBody, falseBody);
}
/**
* See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
*/
public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond,
@NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
return ifCond(null, ifName, cond, trueBody, falseBody);
}
/**
* Constructs a If statement using the tensorflow style control flow operations (Switch and Merge)
*
* If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody
*
* Note that cond, trueBody, and falseBody lambdas are only called once to construct the graph. The constructed graph is used for evaluation.
*
* See <a href="http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf">Tensorflow Control Flow Implementation</a>
*
* @param outputName Name to give the output variable. If null, doesn't rename
* @param ifName The name of the if block. If null, uses "if"
* @param cond A lambda evaluating to the if condition
* @param trueBody A lambda to be executed if cond is true (the if block)
* @param falseBody A lambda to be executed if cond is false (the else block)
* @return The value of trueBody if cond is true, or falseBody if it isn't
*/
public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond,
@NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){
ifName = sd().newBlockName(ifName == null ? "if" : ifName);
NameScope ifScope = sd().withNameScope(ifName);
NameScope condScope = sd().withNameScope("cond");
final SDVariable pred = cond.define(sd());
condScope.close();
if (pred.dataType() != DataType.BOOL) {
//cleanup partially added block
for(SDVariable v : sd().getVariablesInScope(ifScope))
sd().getVariables().remove(v.getVarName());
for(SameDiffOp op : sd().getOpsInScope(ifScope)) {
for(String in : op.getInputsToOp()){
sd().removeArgFromFunction(in, op.getOp());
}
sd().getOps().remove(op.getName());
}
throw new IllegalStateException("Can not use " + pred.getVarName()
+ " as the condition of an If statement, the condition must be a boolean.");
}
final Map<String, SDVariable[]> switches = new HashMap<>();
final Set<String> declared = Sets.newHashSet(sd().variableMap().keySet());
sd().addArgumentInterceptor(new ArgumentInterceptor() {
@Override
public SDVariable intercept(SDVariable argument) {
// if it's declared within the if block, we don't care about it
if(!declared.contains(argument.getVarName()))
return argument;
// if we've already added a switch, move on
if(switches.containsKey(argument.getVarName()))
return switches.get(argument.getVarName())[1];
SDVariable[] s = f().switchOp(argument, pred);
switches.put(argument.getVarName(), s);
return s[1];
}
});
NameScope trueScope = sd().withNameScope("trueBody");
SDVariable trueOut = trueBody.define(sd());
sd().removeArgumentInterceptor();
if(declared.contains(trueOut.getVarName())) {
SDVariable[] s = f().switchOp(trueOut, pred);
switches.put(trueOut.getVarName(), s);
trueOut = s[1];
}
trueScope.close();
final Set<String> declared2 = Sets.newHashSet(sd().variableMap().keySet());
sd().addArgumentInterceptor(new ArgumentInterceptor() {
@Override
public SDVariable intercept(SDVariable argument) {
// if it's declared within the if block, we don't care about it
if(!declared2.contains(argument.getVarName()))
return argument;
// if we've already added a switch, move on
if(switches.containsKey(argument.getVarName()))
return switches.get(argument.getVarName())[0];
SDVariable[] s = f().switchOp(argument, pred);
switches.put(argument.getVarName(), s);
return s[0];
}
});
NameScope falseScope = sd().withNameScope("falseBody");
SDVariable falseOut = falseBody.define(sd());
sd().removeArgumentInterceptor();
if(declared2.contains(falseOut.getVarName())) {
SDVariable[] s = f().switchOp(falseOut, pred);
switches.put(falseOut.getVarName(), s);
falseOut = s[0];
}
falseScope.close();
SDVariable output = f().merge(trueOut, falseOut);
ifScope.close();
return updateVariableNameAndReference(output, outputName);
}
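/*
 * Usage sketch (illustrative only; "sd" is a SameDiff instance and the lambdas assume
 * SameDiffNoArgSingleLambda is define(SameDiff) -> SDVariable): choose between two branches
 * based on a boolean scalar.
 *
 *   SDVariable a = sd.constant(2);
 *   SDVariable out = sd.ifCond(
 *           s -> a.lt(5.0),       // cond: must evaluate to a BOOL scalar
 *           s -> a.mul(2.0),      // trueBody
 *           s -> a.sub(2.0));     // falseBody
 */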
}


@ -411,6 +411,29 @@ public class SDNN extends SDOps {
return updateVariableNameAndReference(ret, name);
}
/**
* Log softmax activation
*
* @param x Input variable
* @param dimension Dimension along which to apply log softmax
* @return Output variable
*/
public SDVariable logSoftmax(SDVariable x, int dimension) {
return logSoftmax(null, x, dimension);
}
/**
* Log softmax activation
*
* @param name Variable name
* @param x Input variable
* @param dimension Dimension along which to apply log softmax
* @return Output variable
*/
public SDVariable logSoftmax(String name, SDVariable x, int dimension) {
validateFloatingPoint("log softmax", x);
SDVariable ret = f().logSoftmax(x, dimension);
return updateVariableNameAndReference(ret, name);
}
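/*
 * Usage sketch (illustrative only; "sd" is a SameDiff instance and "activations" is a
 * hypothetical [minibatch, numClasses] floating point variable): log softmax over the class dimension.
 *
 *   SDVariable logProbs = sd.nn().logSoftmax("logProbs", activations, 1);
 */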
/**
 * Element-wise rectified linear function with specified cutoff:<br>
 * out[i] = in[i] if in[i] >= cutoff
@ -591,6 +614,28 @@ public class SDNN extends SDOps {
return updateVariableNameAndReference(result, name);
}
/**
* Softmax activation
*
* @param x Input variable
* @param dimension Dimension along which to apply softmax
* @return Output variable
*/
public SDVariable softmax(SDVariable x, int dimension) {
return softmax(null, x, dimension);
}
/**
* Softmax activation
*
* @param name Variable name
* @param x Input variable
* @param dimension Dimension along which to apply softmax
* @return Output variable
*/
public SDVariable softmax(String name, SDVariable x, int dimension) {
validateFloatingPoint("softmax", x);
SDVariable result = f().softmax(x, dimension);
return updateVariableNameAndReference(result, name);
}
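/*
 * Usage sketch (illustrative only; "sd" is a SameDiff instance and "logits" is a hypothetical
 * floating point variable): softmax along dimension 0 rather than the default last dimension.
 *
 *   SDVariable probs = sd.nn().softmax("probs", logits, 0);
 */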
/**
 * @param x
 * @return


@ -17,36 +17,47 @@
package org.nd4j.autodiff.samediff.serde;
import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.DataType;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties;
import org.nd4j.graph.IntPair;
import org.nd4j.graph.OpType;
import org.nd4j.graph.VarType;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.Op.Type;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
public class FlatBuffersMapper {
private FlatBuffersMapper() {
}
/**
 * This method converts enums for DataType
 */
public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
switch (type) {
@ -84,80 +95,79 @@ public class FlatBuffersMapper {
/**
 * This method converts enums for DataType
 */
public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
if (val == DataType.FLOAT) {
return org.nd4j.linalg.api.buffer.DataType.FLOAT;
} else if (val == DataType.DOUBLE) {
return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
} else if (val == DataType.HALF) {
return org.nd4j.linalg.api.buffer.DataType.HALF;
} else if (val == DataType.INT32) {
return org.nd4j.linalg.api.buffer.DataType.INT;
} else if (val == DataType.INT64) {
return org.nd4j.linalg.api.buffer.DataType.LONG;
} else if (val == DataType.INT8) {
return org.nd4j.linalg.api.buffer.DataType.BYTE;
} else if (val == DataType.BOOL) {
return org.nd4j.linalg.api.buffer.DataType.BOOL;
} else if (val == DataType.UINT8) {
return org.nd4j.linalg.api.buffer.DataType.UBYTE;
} else if (val == DataType.INT16) {
return org.nd4j.linalg.api.buffer.DataType.SHORT;
} else if (val == DataType.UTF8) {
return org.nd4j.linalg.api.buffer.DataType.UTF8;
} else if (val == DataType.UINT16) {
return org.nd4j.linalg.api.buffer.DataType.UINT16;
} else if (val == DataType.UINT32) {
return org.nd4j.linalg.api.buffer.DataType.UINT32;
} else if (val == DataType.UINT64) {
return org.nd4j.linalg.api.buffer.DataType.UINT64;
} else {
throw new RuntimeException("Unknown datatype: " + val);
}
}
/**
 * This method returns the operation ID for a given op name/type pair.
 */
public static long getOpNum(String name, Op.Type type) {
if (type == Op.Type.LOOP) {
return 0;
} else if (type == Op.Type.RETURN) {
return 40;
} else if (type == Op.Type.CONDITIONAL) {
return 10;
} else if (type == Op.Type.LOOP_COND) {
return 70L;
} else if (type == Type.LOGIC) {
switch (name) {
case Enter.OP_NAME:
return Enter.OP_NUM;
case Exit.OP_NAME:
return Exit.OP_NUM;
case NextIteration.OP_NAME:
return NextIteration.OP_NUM;
case Merge.OP_NAME:
return Merge.OP_NUM;
case Switch.OP_NAME:
return Switch.OP_NUM;
default:
throw new IllegalStateException("Unknown LOGIC op with name: " + name);
}
} else if (type == Op.Type.CUSTOM) {
val name2 = Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase());
if (name2 == null) {
val name3 = Nd4j.getExecutioner().getCustomOperations().get(name);
if (name3 == null) {
return 0;
} else {
return name3.getHash();
}
} else {
return name2.getHash();
}
//return Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase()).getHash();
} else {
@ -212,7 +222,7 @@ public class FlatBuffersMapper {
case OpType.RANDOM:
return Op.Type.RANDOM;
case OpType.LOGIC:
return Type.LOGIC;
case OpType.CUSTOM:
return Op.Type.CUSTOM;
case OpType.PAIRWISE:
@ -269,15 +279,11 @@ public class FlatBuffersMapper {
return OpType.INDEX_REDUCE;
case RANDOM:
return OpType.RANDOM;
case CONDITIONAL:
case LOOP:
case RETURN:
case LOOP_COND:
case LOGIC:
return OpType.LOGIC;
case CUSTOM:
return OpType.CUSTOM;
@ -295,28 +301,25 @@ public class FlatBuffersMapper {
/**
 * This method just converts enums
 */
public static ByteOrder getOrderFromByte(byte val) {
if (val == org.nd4j.graph.ByteOrder.LE) {
return ByteOrder.LITTLE_ENDIAN;
} else {
return ByteOrder.BIG_ENDIAN;
}
}
/**
 * This method returns the current byte order for this JVM as a libnd4j enum
 */
public static byte getOrderAsByte() {
if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) {
return org.nd4j.graph.ByteOrder.BE;
} else {
return org.nd4j.graph.ByteOrder.LE;
}
}
public static DifferentialFunction fromFlatNode(FlatNode fn) {
@ -362,21 +365,23 @@ public class FlatBuffersMapper {
for (int i = 0; i < flatProperties.length; i++) {
flatProperties[i] = fn.properties(i);
}
Map<String, Object> props = FlatBuffersMapper
.mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties));
if (opType == Op.Type.CUSTOM || opType == Type.LOGIC) {
String opName = fn.opName();
Class<?> c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName);
Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum);
DifferentialFunction op;
try {
op = (DifferentialFunction) c.newInstance();
} catch (IllegalAccessException | InstantiationException e) {
throw new RuntimeException("Error creating differential function instance of type " + c);
}
op.setOwnName(name);
//Set input SDVariables:
@ -409,8 +414,10 @@ public class FlatBuffersMapper {
if (opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL) {
ScalarOp sOp = (ScalarOp) op;
sOp.setScalar(scalar);
} else if (opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS
|| opType == Op.Type.VARIANCE
|| opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG
|| opType == Op.Type.REDUCE_SAME) {
val ba = (BaseReduceOp) op; //Reduce3 ops are also all BaseAccumulations
ba.setDimensions(dimensions);
ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
@ -455,8 +462,6 @@ public class FlatBuffersMapper {
int[] sIdx = null;
int[] shape = null;
if (v == null) {
//No op
} else if (v instanceof Boolean) {
@ -469,7 +474,8 @@ public class FlatBuffersMapper {
} else if (v instanceof Long) {
l = new long[]{(Long) v};
} else {
throw new UnsupportedOperationException(
"Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
}
} else if (v instanceof String) {
String str = (String) v;
@ -501,7 +507,8 @@ public class FlatBuffersMapper {
l = (long[]) v;
shape = new int[]{l.length};
} else {
throw new UnsupportedOperationException(
"Unable to map property \"" + e.getKey() + "\" of type " + v.getClass());
}
} else if (v instanceof String[]) {
//String[]
@ -537,7 +544,9 @@ public class FlatBuffersMapper {
} else if (v instanceof long[][][]) {
l = ArrayUtil.flatten((long[][][]) v);
} else {
throw new UnsupportedOperationException(
"Unable to map multidimensional array property \"" + e.getKey() + "\" of type " + v
.getClass());
}
}
}
@ -550,7 +559,8 @@ public class FlatBuffersMapper {
int idxS = FlatProperties.createSVector(fbb, sIdx != null ? sIdx : EMPTY_INT);
int idxShape = FlatProperties.createShapeVector(fbb, shape != null ? shape : EMPTY_INT);
outIdxs[count++] = FlatProperties
.createFlatProperties(fbb, iname, idxI, idxL, idxD, idxA, idxB, idxS, idxShape);
}
return outIdxs;
}


@ -126,12 +126,7 @@ public class LegacyOpMapper {
case CONDITIONAL:
case LOOP:
case LOOP_COND:
case RETURN:
default:
throw new UnsupportedOperationException("Unable to map op " + opNum + " of type " + opType);
}


@ -25,6 +25,11 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
@ -331,6 +336,18 @@ public class DifferentialFunctionClassHolder {
}
public Class<?> customOpClassForHashAndName(long customOpHash, String name){
switch (name) {
case Enter.OP_NAME:
return Enter.class;
case Exit.OP_NAME:
return Exit.class;
case NextIteration.OP_NAME:
return NextIteration.class;
case Merge.OP_NAME:
return Merge.class;
case Switch.OP_NAME:
return Switch.class;
default:
if(customOpHashToClasses.containsKey(customOpHash)){
return customOpHashToClasses.get(customOpHash).get(name);
} else if(customOpHashToClass.containsKey(customOpHash)){
@ -340,6 +357,8 @@ public class DifferentialFunctionClassHolder {
}
}
}
public static DifferentialFunctionClassHolder getInstance() {
return INSTANCE;
}


@ -69,14 +69,10 @@ public interface Op {
CONDITIONAL,
LOOP,
LOOP_COND,
RETURN,
RANDOM,
SUMMARYSTATS,
LOGIC
}
/**


@ -17,11 +17,13 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.Op.Type;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -32,13 +34,38 @@ import java.util.List;
import java.util.Map;
@Data
@NoArgsConstructor
public class Enter extends BaseCompatOp {
protected boolean isConstant;
public Enter(SameDiff sameDiff, SDVariable[] inputs){
super(sameDiff, inputs);
}
public Enter(SameDiff sameDiff, String frameName, SDVariable input){
super(sameDiff, new SDVariable[]{input});
this.frameName = frameName;
isConstant = input.isConstant();
}
public Enter(SameDiff sameDiff, String frameName, SDVariable input, boolean isConstant){
super(sameDiff, new SDVariable[]{input});
this.frameName = frameName;
this.isConstant = isConstant;
}
/**
* WARNING: do not change without changing serialization methods
* See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
* and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
*/
public static final String OP_NAME = "enter";
public static final int OP_NUM = 100;
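/*
 * Sketch of the round-trip invariant these constants exist to preserve (illustrative only):
 *
 *   FlatBuffersMapper.getOpNum(Enter.OP_NAME, Op.Type.LOGIC) == Enter.OP_NUM
 *   DifferentialFunctionClassHolder.getInstance()
 *           .customOpClassForHashAndName(anyHash, Enter.OP_NAME) == Enter.class   // anyHash is hypothetical
 */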
@Override
public String opName() {
return OP_NAME;
}
@Override
@ -62,7 +89,7 @@ public class Enter extends BaseCompatOp {
@Override
public Op.Type opType() {
return Type.LOGIC;
}
@Override


@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
@ -24,6 +25,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.Op.Type;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -34,10 +36,24 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class Exit extends BaseCompatOp {
public Exit(SameDiff sameDiff, SDVariable x) {
super(sameDiff, new SDVariable[]{x});
}
/**
* WARNING: do not change without changing serialization methods
* See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
* and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
*/
public static final String OP_NAME = "exit";
public static final int OP_NUM = 90;
@Override
public String opName() {
return OP_NAME;
}
@Override
@ -61,7 +77,7 @@ public class Exit extends BaseCompatOp {
@Override
public Op.Type opType() {
return Type.LOGIC;
}
@Override


@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.Op.Type;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -41,9 +42,21 @@ public class Merge extends BaseCompatOp {
}
/**
* WARNING: do not change without changing serialization methods
* See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
* and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
*/
public static final String OP_NAME = "merge";
public static final int OP_NUM = 60;
public Merge(SameDiff sd, SDVariable a, SDVariable b){
this(sd, new SDVariable[]{a, b});
}
@Override
public String opName() {
return OP_NAME;
}
@Override
@ -72,7 +85,7 @@ public class Merge extends BaseCompatOp {
@Override
public Op.Type opType() {
return Type.LOGIC;
}
@Override


@ -16,11 +16,13 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.Op.Type;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -31,10 +33,24 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class NextIteration extends BaseCompatOp {
public NextIteration(SameDiff sameDiff, SDVariable x) {
super(sameDiff, new SDVariable[]{x});
}
/**
* WARNING: do not change without changing serialization methods
* See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
* and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
*/
public static final String OP_NAME = "next_iteration";
public static final int OP_NUM = 80;
@Override
public String opName() {
return OP_NAME;
}
@Override
@ -58,7 +74,7 @@ public class NextIteration extends BaseCompatOp {
@Override
public Op.Type opType() {
return Type.LOGIC;
}
@Override


@ -16,12 +16,15 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import com.google.common.collect.Lists;
import lombok.Getter;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.Op.Type;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -37,15 +40,27 @@ import java.util.Map;
*/
public class Switch extends BaseCompatOp {
@Getter
private SDVariable predicate;
public Switch(SameDiff sameDiff, SDVariable input, SDVariable predicate){
super(sameDiff, new SDVariable[]{input, predicate});
this.predicate = predicate;
}
public Switch(){ }
/**
* WARNING: do not change without changing serialization methods
* See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)}
* and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)}
*/
public static final String OP_NAME = "switch";
public static final int OP_NUM = 30;
@Override
public String opName() {
return OP_NAME;
}
@Override
@ -72,7 +87,7 @@ public class Switch extends BaseCompatOp {
@Override
public Op.Type opType() {
return Type.LOGIC;
}
@Override


@ -39,6 +39,9 @@ import java.util.List;
*/
public class LogSoftMax extends DynamicCustomOp {
private Integer dimension = null;
public LogSoftMax(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, i_v);
}
@ -54,6 +57,12 @@ public class LogSoftMax extends DynamicCustomOp {
this(x, x);
}
public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) {
this(sameDiff, i_v);
this.dimension = dimension;
addIArgument(dimension);
}
@Override
public String opName() {
@ -66,8 +75,13 @@ public class LogSoftMax extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
if(dimension == null) {
SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
return Collections.singletonList(ret);
} else {
SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0), dimension);
return Collections.singletonList(ret);
}
}
@Override


@ -43,6 +43,11 @@ public class LogSoftMaxDerivative extends DynamicCustomOp {
super(null, new INDArray[]{in, gradO}, new INDArray[]{out});
}
public LogSoftMaxDerivative(SameDiff sameDiff, SDVariable arg, SDVariable wrt, int dimension) {
this(sameDiff, arg, wrt);
this.addIArgument(dimension);
}
/**
 * The opName of this operation
 *


@ -129,4 +129,39 @@ public class NameScopeTests extends BaseNd4jTest {
}
}
}
@Test
public void testNoNesting(){
SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4);
NameScope scope = SD.withNameScope("test");
SDVariable out = SD.argmax(a);
out.add(45);
scope.close();
assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1"));
}
@Test
public void testNoNesting2(){
SameDiff SD = SameDiff.create();
SDVariable a = SD.constant(4);
SDVariable b = SD.constant(5).lt(4);
NameScope scope = SD.withNameScope("test");
SDVariable out = SD.f().switchOp(a, b)[0];
out.add(45);
scope.close();
assertTrue("Var with name test/switch:1 exists", SD.variableMap().containsKey("test/switch:1"));
}
}


@ -16,12 +16,30 @@
package org.nd4j.autodiff.samediff;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeNotNull;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.OpValidationSuite;
@ -43,7 +61,11 @@ import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@ -53,9 +75,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOpsHolder;
@ -63,21 +83,12 @@ import org.nd4j.weightinit.impl.OneInitScheme;
import org.nd4j.weightinit.impl.UniformInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
/**
 * Created by agibsonccc on 4/11/17.
 */
@Slf4j
public class SameDiffTests extends BaseNd4jTest {
private DataType initialType;
public SameDiffTests(Nd4jBackend b) {
@ -317,7 +328,6 @@ public class SameDiffTests extends BaseNd4jTest {
SameDiff first = SameDiff.create();
SameDiff second = SameDiff.create();
SDVariable firstVar = first.var("one", new long[]{2, 2});
SDVariable secondVar = second.var(firstVar);
assertTrue(firstVar.getArr() == secondVar.getArr());
@ -330,7 +340,6 @@ public class SameDiffTests extends BaseNd4jTest {
SameDiff first = SameDiff.create();
SameDiff second = SameDiff.create();
SDVariable firstVar = first.var("one", new long[]{2, 2});
SDVariable secondVar = second.var(firstVar);
assumeNotNull(firstVar.getArr());
@ -418,7 +427,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
}, xAndY);
INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0);
INDArray assertionForRDiv = Nd4j.valueArrayOf(4, 0.25);
assertEquals(assertionForDiv, sameDiff.getFunction("div").execAndEndResult());
@ -463,7 +471,8 @@ public class SameDiffTests extends BaseNd4jTest {
}, inputs);
INDArray assertion = sumInput.sum(1);
INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum"))
.get("sum");
assertEquals(assertion, out);
}
@ -563,7 +572,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
}, inputVars);
//1 input plus 2 outputs
assertEquals(3, functionDef.variables().size());
@ -573,7 +581,8 @@ public class SameDiffTests extends BaseNd4jTest {
@Test
public void testIfStatementTrueBodyBackwards() {
OpValidationSuite
.ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
SameDiff sameDiff = SameDiff.create();
SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
@Override
@ -584,7 +593,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
};
SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
@ -607,7 +615,6 @@ public class SameDiffTests extends BaseNd4jTest {
};
sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs);
sameDiff.execBackwards(Collections.emptyMap());
SameDiff grad = sameDiff.getFunction("grad");
@ -625,7 +632,8 @@ public class SameDiffTests extends BaseNd4jTest {
@Test
public void testIfStatementTrueBody() {
OpValidationSuite
.ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
SameDiff sameDiff = SameDiff.create();
SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
@ -637,7 +645,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
};
SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
@ -660,7 +667,6 @@ public class SameDiffTests extends BaseNd4jTest {
};
sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs);
sameDiff.exec(Collections.emptyMap());
}
@ -668,7 +674,8 @@ public class SameDiffTests extends BaseNd4jTest {
@Test
public void testIfStatementFalseBody() {
OpValidationSuite
.ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations
SameDiff sameDiff = SameDiff.create();
SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() {
@ -680,7 +687,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
};
SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
@ -697,7 +703,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
};
//false body trigger
SDVariable[] secondInputs = new SDVariable[]{
sameDiff.setupFunction(sameDiff.var("two", new long[]{1, 1}))
@ -790,7 +795,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable weights = sd.var("W", new long[]{nIn, nOut});
SDVariable bias = sd.var("b", new long[]{1, nOut});
SDVariable mmul = sd.mmul("mmul", input, weights);
SDVariable z = mmul.add("z", bias);
SDVariable out = sd.math().tanh(z);
@ -888,7 +892,6 @@ public class SameDiffTests extends BaseNd4jTest {
val f = m.add(2.0);
val s = in2.add(5.0);
val arr = sd.execSingle(null, s.getVarName());
log.info("Result M: {}", m.getArr());
log.info("Result F: {}", f.getArr());
@ -939,7 +942,8 @@ public class SameDiffTests extends BaseNd4jTest {
val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1);
val input1 = sd.var("input", matrix);
val input2 = sd.var("input2", vector);
val output = sd
.mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build());
output.eval();
assertArrayEquals(new long[]{3, 1}, output.getShape());
}
@ -1026,7 +1030,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
}, inputs);
SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions");
Map<String, INDArray> inputsSubset = new HashMap<>();
inputsSubset.put("y", inputs.get("y"));
@ -1076,7 +1079,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
}, inputs);
SameDiff logisticPrediction = sameDiffOuter.getFunction("logisticPredictions");
List<String> logisticOpNameAssertions = Arrays.asList("mmul", "sigmoid");
@ -1146,7 +1148,8 @@ public class SameDiffTests extends BaseNd4jTest {
Activation.SOFTPLUS,
Activation.SOFTSIGN,
Activation.HARDTANH,
Activation.CUBE,
//WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426
Activation.RELU, //JVM crash
Activation.LEAKYRELU //JVM crash
};
@ -1289,8 +1292,9 @@ public class SameDiffTests extends BaseNd4jTest {
sd.exec(Collections.emptyMap(), sd.outputs());
for (int i = 0; i < 4; i++) {
assertEquals(1, out.getArr().get(all(), NDArrayIndex.point(i), all(), all()).getInt(0));
}
}
@ -1327,7 +1331,6 @@ public class SameDiffTests extends BaseNd4jTest {
INDArray means = Nd4j.create(new float[]{2, 4}, new long[]{1, 2});
INDArray vars = Nd4j.create(new float[]{6, 8}, new long[]{1, 2});
SDVariable sdCounts = sd.var("counts", counts);
SDVariable sdMeans = sd.var("means", means);
SDVariable sdVars = sd.var("vars", vars);
@ -1363,7 +1366,6 @@ public class SameDiffTests extends BaseNd4jTest {
int imgH = 28;
int imgW = 28;
SameDiff sd = SameDiff.create();
INDArray depthWeightArr = Nd4j.create(kH, kW, nIn, depthWise);
@ -1720,7 +1722,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable in1 = sd.var("in1", ia); SDVariable in1 = sd.var("in1", ia);
SDVariable in2 = sd.var("in2", ib); SDVariable in2 = sd.var("in2", ib);
SDVariable t; SDVariable t;
INDArray expOut; INDArray expOut;
switch (i) { switch (i) {
@ -1835,7 +1836,8 @@ public class SameDiffTests extends BaseNd4jTest {
val origShape = new long[]{3, 4};
for (int i = 0; i < 3; i++) {
for (Pair<INDArray, String> p : NDArrayCreationUtil
.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
INDArray inArr = p.getFirst().muli(100);
SameDiff sd = SameDiff.create();
@ -1875,7 +1877,8 @@ public class SameDiffTests extends BaseNd4jTest {
val shape = origShape.clone();
shape[i] = 1;
for (Pair<INDArray, String> p : NDArrayCreationUtil
.getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) {
INDArray inArr = p.getFirst().muli(100);
SameDiff sd = SameDiff.create();
@ -1912,7 +1915,8 @@ public class SameDiffTests extends BaseNd4jTest {
val origShape = new long[]{3, 4};
for (int i = 0; i < 3; i++) {
for (Pair<INDArray, String> p : NDArrayCreationUtil
.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) {
INDArray inArr = p.getFirst().muli(100);
SameDiff sd = SameDiff.create();
@ -1939,7 +1943,8 @@ public class SameDiffTests extends BaseNd4jTest {
val shape = origShape.clone();
shape[i] = 1;
for (Pair<INDArray, String> p : NDArrayCreationUtil
.getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) {
INDArray inArr = p.getFirst().muli(100);
SameDiff sd = SameDiff.create();
@ -2214,7 +2219,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable in = sd.var("in", 1, 2); SDVariable in = sd.var("in", 1, 2);
sd.associateArrayWithVariable(ia, in); sd.associateArrayWithVariable(ia, in);
INDArray expFinite = Nd4j.create(new boolean[]{true, true}); INDArray expFinite = Nd4j.create(new boolean[]{true, true});
SDVariable finite = sd.math().isFinite(in); SDVariable finite = sd.math().isFinite(in);
@ -2263,7 +2267,6 @@ public class SameDiffTests extends BaseNd4jTest {
SDVariable result3 = x.get(SDIndex.interval(3, 8));
assertEquals(expOut3, result3.eval());
INDArray expOut4 = arr.get(NDArrayIndex.point(5), NDArrayIndex.interval(3, 8)).reshape(5);
SDVariable result4 = x.get(SDIndex.point(5), SDIndex.interval(3, 8));
assertEquals(expOut4, result4.eval());
@ -2295,7 +2298,6 @@ public class SameDiffTests extends BaseNd4jTest {
INDArray s3a = s3.eval();
assertEquals(s3a, y3);
INDArray y4 = arr.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.interval(3, 5));
SDVariable s4 = x.get(SDIndex.point(2), SDIndex.all(), SDIndex.interval(3, 5));
INDArray s4a = s4.eval();
@ -2409,7 +2411,6 @@ public class SameDiffTests extends BaseNd4jTest {
},
new int[]{3, 2, 4});
SDVariable x = sd.var(arr);
SDVariable result = sd.permute(x, 1, 0, 2);
assertEquals(expOut, result.eval());
@ -2488,7 +2489,6 @@ public class SameDiffTests extends BaseNd4jTest {
assertEquals(externalGrad.mul(0.5), gradVar);
//Test model serialization:
}
@ -2723,7 +2723,6 @@ public class SameDiffTests extends BaseNd4jTest {
.build();
sd.setTrainingConfig(c);
sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr, null).toMultiDataSet()), 1);
INDArray out = tanh.eval();
@ -2767,7 +2766,6 @@ public class SameDiffTests extends BaseNd4jTest {
.build();
sd.setTrainingConfig(c);
sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(new INDArray[]{inArr, inArr2}, null)), 1);
INDArray out = tanh.eval();
@ -2859,7 +2857,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
final INDArray out = Nd4j.concat(2, output).norm2();
SameDiff sd = SameDiff.create();
final SDVariable sdInput = sd.var("input", input);
@ -2905,7 +2902,6 @@ public class SameDiffTests extends BaseNd4jTest {
}
final INDArray out = Nd4j.concat(2, output).norm2();
SameDiff sd = SameDiff.create();
final SDVariable sdInput = sd.var("input", input);
@ -2917,13 +2913,11 @@ public class SameDiffTests extends BaseNd4jTest {
outputSlices[0] = x_0;
outputSlices[0] = sd.expandDims("X_0-e", outputSlices[0], 2);
final val x_1 = inputSlices[1];
outputSlices[1] = x_1;
outputSlices[1] = outputSlices[1].add(sd.squeeze("X_0-s", outputSlices[0], 2));
outputSlices[1] = sd.expandDims("X_1-e", outputSlices[1], 2);
SDVariable t = sd.concat(2, outputSlices);
t.norm2("out");
String err = OpValidation.validate(new TestCase(sd)
@ -3192,7 +3186,8 @@ public class SameDiffTests extends BaseNd4jTest {
fail("Expected exception"); fail("Expected exception");
} catch (Exception t) { } catch (Exception t) {
String msg = t.getMessage(); String msg = t.getMessage();
assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg.contains(Arrays.toString(v.placeholderShape()))); assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg
.contains(Arrays.toString(v.placeholderShape())));
} }
} }
@ -3201,7 +3196,8 @@ public class SameDiffTests extends BaseNd4jTest {
fail("Expected exception"); fail("Expected exception");
} catch (Exception t) { } catch (Exception t) {
String msg = t.getMessage(); String msg = t.getMessage();
assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg.contains(Arrays.toString(v.placeholderShape()))); assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg
.contains(Arrays.toString(v.placeholderShape())));
} }
try { try {
@ -3209,7 +3205,8 @@ public class SameDiffTests extends BaseNd4jTest {
fail("Expected exception"); fail("Expected exception");
} catch (Exception t) { } catch (Exception t) {
String msg = t.getMessage(); String msg = t.getMessage();
assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg.contains(Arrays.toString(v.placeholderShape()))); assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg
.contains(Arrays.toString(v.placeholderShape())));
} }
} }
@ -3258,7 +3255,6 @@ public class SameDiffTests extends BaseNd4jTest {
INDArray out = m.get("softmax");
INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3);
Map<String, INDArray> allPh = new HashMap<>();
allPh.put("in", inputArr);
@ -3299,7 +3295,6 @@ public class SameDiffTests extends BaseNd4jTest {
INDArray out = m.get("softmax");
INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3);
Map<String, INDArray> allPh = new HashMap<>();
allPh.put("in", inputArr);
@ -3447,6 +3442,129 @@ public class SameDiffTests extends BaseNd4jTest {
}
}
@Test
public void testIf() throws IOException {
SameDiff SD = SameDiff.create();
SDVariable a = SD.placeHolder("a", DataType.DOUBLE);
SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
SDVariable output = SD.ifCond("out", null, (sd) -> a.lt(b), (sd) -> c, (sd) -> c.add(5));
Map<String, INDArray> firstBranch = Maps.newHashMap();
firstBranch.put("a", Nd4j.createFromArray(3.0));
assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out"));
Map<String, INDArray> secondBranch = Maps.newHashMap();
secondBranch.put("a", Nd4j.createFromArray(7.0));
assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out"));
//TODO complains that it can't deserialize a meta type, but there are no meta-type ops here;
// looks like a difference between Op.Type and OpType. Switch is saved as an OpType.LOGIC.
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out"));
assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out"));
}
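For reference, a minimal plain-Java sketch of the scalar logic this ifCond graph encodes (the helper name is hypothetical, not part of the test or the SameDiff API):

// Hypothetical reference helper mirroring testIf above.
static double ifCondReference(double a) {
    double b = 5.0, c = 9.0;
    // true branch returns c, false branch returns c + 5
    return (a < b) ? c : c + 5;   // a = 3.0 -> 9.0, a = 7.0 -> 14.0
}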
@Test
public void testNestedIf() throws IOException {
SameDiff SD = SameDiff.create();
SDVariable a = SD.var("a", Nd4j.createFromArray(2.0));
SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0));
SDVariable output = SD.ifCond("out", null,
(sd) -> a.lt(b),
(sd) -> sd.ifCond(
(sd2) -> d.lte(0),
(sd2) -> c.add(1),
(sd2) -> d),
(sd) -> c.add(5));
INDArray out = output.eval();
assertEquals(Nd4j.createFromArray(10.0), out);
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(Nd4j.createFromArray(10.0), SD.exec(null, "out").get("out"));
}
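As a sanity check on the expected value, a plain-Java sketch of the nested branching (helper name hypothetical):

// Hypothetical reference helper: the outer predicate selects the nested if.
static double nestedIfReference(double a, double b, double c, double d) {
    return (a < b) ? ((d <= 0) ? c + 1 : d) : c + 5;   // (2, 5, 9, -7) -> 10.0
}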
@Test
public void testWhile() throws IOException {
SameDiff SD = SameDiff.create();
SDVariable countIn = SD.constant(5);
SDVariable sumIn = SD.constant(0);
SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn},
(sd, vars) -> vars[0].gt(0),
(sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});
INDArray out = sum[1].eval();
assertEquals(15, out.getInt(0));
String outName = sum[1].getVarName();
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(15, SD.exec(null, outName).get(outName).getInt(0));
}
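The loop state here is {count, sum}; a plain-Java equivalent of what the whileLoop graph computes (helper name hypothetical):

// Hypothetical reference helper mirroring the cond/body lambdas in testWhile.
static int whileReference() {
    int count = 5, sum = 0;
    while (count > 0) {   // cond: vars[0].gt(0)
        sum += count;     // body: vars[1].add(vars[0])
        count -= 1;       // body: vars[0].sub(1)
    }
    return sum;           // 5 + 4 + 3 + 2 + 1 = 15
}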
@Test
@Ignore
public void testNestedWhile() throws IOException {
SameDiff SD = SameDiff.create();
SDVariable countIn = SD.constant(5);
SDVariable sumIn = SD.constant(0);
SDVariable sum2 = SD.constant(0);
//TODO creating constant instead of using sum2 causes errors
SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
(sd, vars) -> vars[0].gt(0),
(sd, vars) -> new SDVariable[]{vars[0].sub(1),
vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2},
(sd2, vars2) -> vars2[0].gt(0),
(sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])});
INDArray out = sum[1].eval();
assertEquals(35, out.getInt(0));
String outName = sum[1].getVarName();
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(35, SD.exec(null, outName).get(outName).getInt(0));
}
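For this ignored test, the expected 35 is the sum of the inner-loop results 15 + 10 + 6 + 3 + 1; a plain-Java sketch (helper name hypothetical):

// Hypothetical reference helper: the inner loop restarts from the current count.
static int nestedWhileReference() {
    int count = 5, sum = 0;
    while (count > 0) {
        int inner = 0;
        for (int i = count; i > 0; i--) {   // inner whileLoop over {vars[0], sum2}
            inner += i;
        }
        sum += inner;   // 15 + 10 + 6 + 3 + 1 = 35
        count -= 1;
    }
    return sum;
}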
@Test
public void testNestedWhileIf() throws IOException {
SameDiff SD = SameDiff.create();
SDVariable countIn = SD.constant(5);
SDVariable sumIn = SD.constant(0);
SDVariable hundred = SD.constant(100);
SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
(sd, vars) -> vars[0].gte(0),
(sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(
sd.ifCond((sd2) -> vars[0].eq(0),
(sd2) -> vars[0].add(100), //TODO replace with hundred and things break
(sd2) -> vars[0])
)});
INDArray out = sum[1].eval();
assertEquals(115, out.getInt(0));
String outName = sum[1].getVarName();
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(115, SD.exec(null, outName).get(outName).getInt(0));
}
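The expected 115 is 5 + 4 + 3 + 2 + 1 plus the 100 added by the ifCond on the final iteration, where count reaches 0; a plain-Java sketch (helper name hypothetical):

// Hypothetical reference helper: the ifCond inside the loop body adds 100 once.
static int nestedWhileIfReference() {
    int count = 5, sum = 0;
    while (count >= 0) {                           // cond: vars[0].gte(0)
        sum += (count == 0) ? count + 100 : count; // body ifCond
        count -= 1;
    }
    return sum;                                    // 5+4+3+2+1+0+100 = 115
}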
}