diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 46cbb1523..133686b57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -451,6 +451,17 @@ public abstract class DifferentialFunction { } } + public void replaceArg(int i, SDVariable newArg){ + if(sameDiff != null){ + sameDiff.replaceArgFor(i, newArg, this); + if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){ + sameDiff.removePropertyToResolve(this, args()[i].getVarName()); + } else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){ + sameDiff.addPropertyToResolve(this, newArg.getVarName()); + } + } + } + /** * Return the output variables for this differential function. @@ -652,9 +663,9 @@ public abstract class DifferentialFunction { scope = ""; else scope = scope + "/"; - String varName = scope + sameDiff.generateNewVarName(opName(),argIndex); + String varName = scope + sameDiff.generateNewVarName(opName(),argIndex).replace(":", "_"); while(sameDiff.functionExists(varName)) { - varName = scope + sameDiff.generateNewVarName(opName(), argIndex); + varName = scope + sameDiff.generateNewVarName(opName(), argIndex).replace(":", "_"); argIndex++; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index f37c1658d..34800ca07 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -16,6 +16,11 @@ package org.nd4j.autodiff.functions; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.Data; import lombok.NonNull; import lombok.val; @@ -30,36 +35,183 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; -import org.nd4j.linalg.api.ops.impl.indexaccum.*; +import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex; +import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; -import org.nd4j.linalg.api.ops.impl.layers.convolution.*; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; -import org.nd4j.linalg.api.ops.impl.loss.*; -import 
org.nd4j.linalg.api.ops.impl.loss.bp.*; -import org.nd4j.linalg.api.ops.impl.reduce.*; +import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization; +import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; +import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss; +import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss; +import org.nd4j.linalg.api.ops.impl.loss.HingeLoss; +import org.nd4j.linalg.api.ops.impl.loss.HuberLoss; +import org.nd4j.linalg.api.ops.impl.loss.L2Loss; +import org.nd4j.linalg.api.ops.impl.loss.LogLoss; +import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss; +import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss; +import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss; +import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; +import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; +import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; +import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits; +import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss; +import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; +import 
org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; +import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; +import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; +import org.nd4j.linalg.api.ops.impl.reduce.Moments; +import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; +import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; +import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce.bp.*; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul; import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp; -import org.nd4j.linalg.api.ops.impl.reduce.floating.*; +import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; +import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy; +import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy; +import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; +import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1; +import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; +import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax; +import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy; +import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm; import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero; import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.same.AMax; import org.nd4j.linalg.api.ops.impl.reduce.same.AMin; +import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; import org.nd4j.linalg.api.ops.impl.reduce.same.Max; import org.nd4j.linalg.api.ops.impl.reduce.same.Min; -import org.nd4j.linalg.api.ops.impl.reduce.same.*; -import org.nd4j.linalg.api.ops.impl.reduce3.*; +import org.nd4j.linalg.api.ops.impl.reduce.same.Prod; +import org.nd4j.linalg.api.ops.impl.reduce.same.Sum; +import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance; +import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; +import org.nd4j.linalg.api.ops.impl.reduce3.Dot; +import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; +import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; +import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; +import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; +import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; +import org.nd4j.linalg.api.ops.impl.scalar.LogX; import org.nd4j.linalg.api.ops.impl.scalar.Pow; -import 
org.nd4j.linalg.api.ops.impl.scalar.*; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.*; -import org.nd4j.linalg.api.ops.impl.scatter.*; -import org.nd4j.linalg.api.ops.impl.shape.*; +import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative; +import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; +import org.nd4j.linalg.api.ops.impl.scalar.Relu6; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet; +import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction; +import org.nd4j.linalg.api.ops.impl.scalar.Step; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; +import org.nd4j.linalg.api.ops.impl.shape.Broadcast; +import org.nd4j.linalg.api.ops.impl.shape.Concat; +import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; +import org.nd4j.linalg.api.ops.impl.shape.Cross; +import org.nd4j.linalg.api.ops.impl.shape.Diag; +import org.nd4j.linalg.api.ops.impl.shape.DiagPart; +import org.nd4j.linalg.api.ops.impl.shape.ExpandDims; +import org.nd4j.linalg.api.ops.impl.shape.Gather; +import org.nd4j.linalg.api.ops.impl.shape.GatherNd; +import org.nd4j.linalg.api.ops.impl.shape.MergeAvg; +import org.nd4j.linalg.api.ops.impl.shape.MergeMax; +import org.nd4j.linalg.api.ops.impl.shape.MeshGrid; +import org.nd4j.linalg.api.ops.impl.shape.OneHot; +import org.nd4j.linalg.api.ops.impl.shape.OnesLike; +import org.nd4j.linalg.api.ops.impl.shape.ParallelStack; +import org.nd4j.linalg.api.ops.impl.shape.Permute; +import org.nd4j.linalg.api.ops.impl.shape.Rank; +import org.nd4j.linalg.api.ops.impl.shape.ReductionShape; +import org.nd4j.linalg.api.ops.impl.shape.Repeat; +import org.nd4j.linalg.api.ops.impl.shape.Reshape; +import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; +import org.nd4j.linalg.api.ops.impl.shape.Size; +import org.nd4j.linalg.api.ops.impl.shape.SizeAt; +import org.nd4j.linalg.api.ops.impl.shape.Slice; +import org.nd4j.linalg.api.ops.impl.shape.Squeeze; +import org.nd4j.linalg.api.ops.impl.shape.Stack; +import org.nd4j.linalg.api.ops.impl.shape.StridedSlice; +import org.nd4j.linalg.api.ops.impl.shape.Tile; +import org.nd4j.linalg.api.ops.impl.shape.Transpose; +import org.nd4j.linalg.api.ops.impl.shape.Unstack; +import org.nd4j.linalg.api.ops.impl.shape.ZerosLike; import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; 
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; @@ -77,37 +229,165 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm; import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.*; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.*; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign; +import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; +import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd; +import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D; +import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention; +import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition; +import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch; +import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill; +import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan; +import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation; +import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing; +import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor; +import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; +import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace; +import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB; +import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; +import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; +import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd; +import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum; 
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; -import org.nd4j.linalg.api.ops.impl.transforms.same.*; -import org.nd4j.linalg.api.ops.impl.transforms.segment.*; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.*; -import org.nd4j.linalg.api.ops.impl.transforms.strict.*; +import org.nd4j.linalg.api.ops.impl.transforms.same.Abs; +import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil; +import 
org.nd4j.linalg.api.ops.impl.transforms.same.Cube; +import org.nd4j.linalg.api.ops.impl.transforms.same.Floor; +import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; +import org.nd4j.linalg.api.ops.impl.transforms.same.Negative; +import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal; +import org.nd4j.linalg.api.ops.impl.transforms.same.Round; +import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; +import org.nd4j.linalg.api.ops.impl.transforms.same.Square; +import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax; +import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean; +import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin; +import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd; +import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN; +import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1; +import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU; +import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid; +import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p; +import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid; +import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU; +import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin; +import 
org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh; +import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus; +import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish; +import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan; +import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.custom.RandomNormal; -import org.nd4j.linalg.api.ops.random.impl.*; +import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; +import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; +import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; +import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; +import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution; +import org.nd4j.linalg.api.ops.random.impl.Range; +import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; +import org.nd4j.linalg.api.ops.random.impl.UniformDistribution; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.util.ArrayUtil; -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - /** * */ @@ -1611,11 +1891,24 @@ public class DifferentialFunctionFactory { } + public SDVariable logSoftmax(SDVariable i_v, int dimension) { + validateDifferentialFunctionsameDiff(i_v); + return new LogSoftMax(sameDiff(), i_v, dimension).outputVariable(); + + } + + public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) { validateDifferentialFunctionsameDiff(arg); return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable(); } + + public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt, int dimension) { + validateDifferentialFunctionsameDiff(arg); + return new LogSoftMaxDerivative(sameDiff(), arg, wrt, dimension).outputVariable(); + } + public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... 
dimension) { return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable(); } @@ -2296,6 +2589,22 @@ public class DifferentialFunctionFactory { return tile(func, ArrayUtil.toInts(input.getShape())); } + public SDVariable enter(SDVariable x, String frameName){ + return new Enter(sameDiff, frameName, x).outputVariable(); + } + + public SDVariable enter(SDVariable x, String frameName, boolean isConstant){ + return new Enter(sameDiff, frameName, x, isConstant).outputVariable(); + } + + public SDVariable exit(SDVariable x){ + return new Exit(sameDiff, x).outputVariable(); + } + + public SDVariable nextIteration(SDVariable x){ + return new NextIteration(sameDiff, x).outputVariable(); + } + public String toString() { return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ArgumentInterceptor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ArgumentInterceptor.java new file mode 100644 index 000000000..a1f4734fc --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ArgumentInterceptor.java @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff; + +/** + * Internal interface used to apply a transform to any arguments used within a certain block + * + * Intended for internal use only. 
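+ *
+ * <p>A minimal illustrative sketch (hypothetical, not part of this change) of an
+ * interceptor implementation; a real interceptor may return a replacement for the argument:
+ * <pre>{@code
+ * ArgumentInterceptor passThrough = new ArgumentInterceptor() {
+ *     @Override
+ *     public SDVariable intercept(SDVariable argument) {
+ *         return argument;    // identity: leave every argument unchanged
+ *     }
+ * };
+ * }</pre>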
+ * + * Managed with {@link SameDiff#addArgumentInterceptor(ArgumentInterceptor)}, {@link SameDiff#removeArgumentInterceptor()}, + * {@link SameDiff#pauseArgumentInterceptor()}, and {@link SameDiff#unpauseArgumentInterceptor()} + * + */ +public interface ArgumentInterceptor { + SDVariable intercept(SDVariable argument); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 33a773415..8ac789f9b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff; +import java.util.Objects; import lombok.*; import lombok.extern.slf4j.Slf4j; import onnx.OnnxProto3; @@ -91,7 +92,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName); String nameScope = sameDiff.currentNameScope(); - if(nameScope != null){ + if(nameScope != null && !varName.startsWith(nameScope + "/")){ varName = nameScope + "/" + varName; } @@ -1785,26 +1786,6 @@ public class SDVariable extends DifferentialFunction implements Serializable { (variableType == VariableType.PLACEHOLDER && shape != null ? ",shape=" + Arrays.toString(shape): "") + ")"; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - if (!super.equals(o)) return false; - - SDVariable that = (SDVariable) o; - - if (varName != null ? !varName.equals(that.varName) : that.varName != null) return false; - return weightInitScheme != null ? weightInitScheme.equals(that.weightInitScheme) : that.weightInitScheme == null; - } - - @Override - public int hashCode() { - int result = super.hashCode(); - result = 31 * result + (varName != null ? varName.hashCode() : 0); - result = 31 * result + (weightInitScheme != null ? weightInitScheme.hashCode() : 0); - return result; - } - @Override public String onnxName() { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); @@ -1965,5 +1946,36 @@ public class SDVariable extends DifferentialFunction implements Serializable { } return x; } - + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SDVariable)) { + return false; + } + + SDVariable that = (SDVariable) o; + + if (!Objects.equals(varName, that.varName)) { + return false; + } + if (variableType != that.variableType) { + return false; + } + if(sameDiff != that.sameDiff){ + return false; + } + return dataType == that.dataType; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + (varName != null ? varName.hashCode() : 0); + result = 31 * result + (variableType != null ? variableType.hashCode() : 0); + result = 31 * result + (dataType != null ? 
dataType.hashCode() : 0); + return result; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 04bc59603..2e7829913 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -53,6 +53,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.controlflow.If; import org.nd4j.linalg.api.ops.impl.controlflow.While; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; @@ -246,6 +247,14 @@ public class SameDiff extends SDBaseOps { private boolean resolvedVariables = false; + + @Getter + private Stack<ArgumentInterceptor> argumentInterceptors = new Stack<>(); + @Getter + private Set<ArgumentInterceptor> pausedArgumentInterceptors = new HashSet<>(); + + private Set<String> blockNames = new HashSet<>(); + @Getter @Setter boolean logExecution = true; @@ -472,7 +481,10 @@ public class SameDiff extends SDBaseOps { if(scope == null){ return name; } - return scope + "/" + name; + if(!name.startsWith(scope + "/")) + return scope + "/" + name; + else + return name; } //Intentionally package private @@ -533,6 +545,24 @@ public class SameDiff extends SDBaseOps { } + public List<SameDiffOp> getOpsInScope(NameScope scope){ + ArrayList<SameDiffOp> ops = new ArrayList<>(); + for(SameDiffOp v : this.ops.values()){ + if(v.getName().startsWith(scope.getName())) + ops.add(v); + } + return ops; + } + + public List<SDVariable> getVariablesInScope(NameScope scope){ + ArrayList<SDVariable> vars = new ArrayList<>(); + for(SDVariable v : variables()){ + if(v.getVarName().startsWith(scope.getName())) + vars.add(v); + } + return vars; + } + /** * @param sameDiff * @return */ @@ -1109,6 +1139,19 @@ public class SameDiff extends SDBaseOps { } } + /** + * Remove a property to resolve added with {@link #addPropertyToResolve(DifferentialFunction, String)} + * + * @param forFunction the function to remove the property to resolve for + * @param arrayName the array name + */ + public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName) { + if (propertiesToResolve.containsKey(forFunction.getOwnName())) { + List<String> newVal = propertiesToResolve.get(forFunction.getOwnName()); + newVal.remove(arrayName); + } + } + /** * Return the properties to resolve for the given function. * This is typically used right before execution in model import in @@ -1272,6 +1315,92 @@ public class SameDiff extends SDBaseOps { } } + /** + * Add a new argument interceptor to the interceptor stack + * + * For internal use only. + * + * When an op is added with arguments, the most recent argument interceptor is called on it. + * If ops are added in that interceptor, the next most recent will be called on their args, and so on. 
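+ *
+ * <p>Hypothetical usage sketch (illustrative only; mirrors how the control flow builders below use the stack):
+ * <pre>{@code
+ * sd.addArgumentInterceptor(interceptor);  // applies to args of ops created from here on
+ * // ... create the ops whose incoming arguments should be intercepted ...
+ * sd.removeArgumentInterceptor();          // pop the interceptor once the block is built
+ * }</pre>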
+ * + * @param interceptor the argument interceptor to add + */ + public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){ + argumentInterceptors.push(interceptor); + } + + private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor interceptor){ + return pausedArgumentInterceptors.contains(interceptor); + } + + private ArgumentInterceptor getArgumentInterceptorToUse(){ + + if(argumentInterceptors.isEmpty()) + return null; + + ArgumentInterceptor use = argumentInterceptors.peek(); + int i = 1; + while(isArgumentInterceptorPaused(use)){ + if(argumentInterceptors.size() - i < 0) + return null; + + use = argumentInterceptors.elementAt(argumentInterceptors.size() - i); + i++; + } + + return use; + } + + /** + * Remove the top (most recently added) argument interceptor + * + * For internal use only. + */ + public void removeArgumentInterceptor(){ + if(!argumentInterceptors.isEmpty()) + argumentInterceptors.pop(); + } + + /** + * Pause the top (most recently added) argument interceptor + * + * For internal use only. + */ + public void pauseArgumentInterceptor(){ + pausedArgumentInterceptors.add(argumentInterceptors.peek()); + } + + /** + * Pause the given argument interceptor + * + * For internal use only. + * + * @param interceptor the argument interceptor to pause + */ + public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){ + pausedArgumentInterceptors.add(interceptor); + } + + /** + * Unpause the top (most recently added) argument interceptor + * + * For internal use only. + */ + public void unpauseArgumentInterceptor(){ + pausedArgumentInterceptors.remove(argumentInterceptors.peek()); + } + + /** + * Unpause the given argument interceptor + * + * For internal use only. + * + * @param interceptor the argument interceptor to unpause + */ + public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor){ + pausedArgumentInterceptors.remove(interceptor); + } + /** * Adds incoming arguments for the specified differential function to the graph * * @param variables Variables * @param function Function */ public void addArgsFor(String[] variables, DifferentialFunction function) { + + ArgumentInterceptor interceptor = getArgumentInterceptorToUse(); + + if(interceptor != null) { + pauseArgumentInterceptor(interceptor); + for (int i = 0; i < variables.length; i++) { + variables[i] = interceptor.intercept(getVariable(variables[i])).getVarName(); + } + unpauseArgumentInterceptor(interceptor); + } + if (function.getOwnName() == null) throw new ND4JIllegalStateException("Instance id can not be null. 
Function not initialized properly"); @@ -1309,7 +1449,6 @@ public class SameDiff extends SDBaseOps { } } - /** * Adds incoming arguments for the specified differential function to the graph * @@ -1317,6 +1456,7 @@ * @param function Function */ public void addArgsFor(SDVariable[] variables, DifferentialFunction function) { + String[] varNames = new String[variables.length]; for (int i = 0; i < varNames.length; i++) { if (variables[i] == null) @@ -1326,6 +1466,58 @@ public class SameDiff extends SDBaseOps { addArgsFor(varNames, function); } + /** + * Replaces the argument at i with newArg for function. + * Does not use (or remove) the ArgumentInterceptor stack. + */ + public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function){ + + Preconditions.checkArgument(i < function.args().length, "Index out of range: function " + + function.getOwnName() + " only has " + function.args().length + " args but you are trying " + + "to replace the argument at " + i); + + String oldName = function.arg(i).getVarName(); + String newName = newArg.getVarName(); + + if(function.arg(i).isPlaceHolder() && !newArg.isPlaceHolder()){ + boolean otherPlaceholders = false; + for(int j = 0 ; j < function.argNames().length ; j++){ + if(j == i) + continue; + + if(function.arg(j).isPlaceHolder()) + otherPlaceholders = true; + } + + if(!otherPlaceholders) + placeHolderFunctions.remove(function.getOwnName()); + } else if(!function.arg(i).isPlaceHolder() && newArg.isPlaceHolder()){ + if(!placeHolderFunctions.contains(function.getOwnName())) + placeHolderFunctions.add(function.getOwnName()); + } + + List<String> oldArgs = ops.get(function.getOwnName()).getInputsToOp(); + oldArgs = new ArrayList<>(oldArgs); + oldArgs.set(i, newName); + ops.get(function.getOwnName()).setInputsToOp(oldArgs); + + List<String> funcs = this.variables.get(newName).getInputsForOp(); + + if (funcs == null) { + funcs = new ArrayList<>(); + this.variables.get(newName).setInputsForOp(funcs); + } + if(!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. + funcs.add(function.getOwnName()); + + List<String> oldFuncs = this.variables.get(oldName).getInputsForOp(); + if(oldFuncs != null) { + if(!ArrayUtils.contains(function.argNames(), oldName)) + oldFuncs.remove(function.getOwnName()); + } + + } + /** * Get the differential function (if any) that this variable is the output for * @@ -1519,6 +1711,7 @@ //A bit of a hack for TF import: some TF graphs have Switch ops, where the output of one branch isn't consumed // by any ops. Consequently, during execution this "output" might never be available. So we'll exclude the output of execution here + // This applies to SameDiff while loops as well if(o.getOp() instanceof Switch){ continue; } @@ -2239,6 +2432,7 @@ if (name == null || name.length() < 1) name = getNewVarName(); SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null); + name = v.getVarName(); variables.put(name, Variable.builder().name(name).variable(v).build()); constantArrays.put(name, new DeviceLocalNDArray(constant)); return v; } @@ -2305,6 +2499,7 @@ public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... 
shape) { String withScope = nameWithScope(name); + if (variables.containsKey(withScope)) { if(nameScopes.isEmpty()){ throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" @@ -3414,12 +3609,9 @@ /** - * Creates a while statement - * - * @param sameDiffConditional - * @param loopBody - * @return + * @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} */ + @Deprecated public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition loopBody @@ -3435,11 +3627,9 @@ } /** - * @param conditional - * @param trueBody - * @param falseBody - * @return + * @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} */ + @Deprecated public If ifStatement(SameDiffConditional conditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition trueBody, @@ -5466,5 +5656,27 @@ return out; } + /** + * For internal use only. + * Creates a new distinct block name from baseName. + * Block names are used by If and While + */ + public String newBlockName(String baseName){ + + if(baseName == null) + return null; + + if(!blockNames.contains(baseName)){ + blockNames.add(baseName); + return baseName; + } else { + int i = 1; + while(blockNames.contains(baseName + "_" + i)){ + i++; + } + blockNames.add(baseName + "_" + i); + return baseName + "_" + i; + } + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffLambda.java new file mode 100644 index 000000000..c9efc1428 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffLambda.java @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff; + +/** + * A basic SameDiff lambda, used in while loop creation (the body). 
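+ *
+ * <p>Hypothetical sketch of a loop body (variable names are illustrative, not from this change):
+ * <pre>{@code
+ * SameDiffLambda body = new SameDiffLambda() {
+ *     @Override
+ *     public SDVariable[] define(SameDiff sd, SDVariable[] inputs) {
+ *         // return the updated loop variables, same length and order as inputs
+ *         return new SDVariable[]{inputs[0].add(1.0)};
+ *     }
+ * };
+ * }</pre>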
+ */ +public interface SameDiffLambda { + SDVariable[] define(SameDiff sameDiff, SDVariable[] inputs); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffNoArgSingleLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffNoArgSingleLambda.java new file mode 100644 index 000000000..4c3f7a86d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffNoArgSingleLambda.java @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff; + +/** + * A SameDiff lambda with only one output and no arguments. Used in if condition creation (the condition and bodies). + */ +public interface SameDiffNoArgSingleLambda { + SDVariable define(SameDiff sameDiff); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffSingleLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffSingleLambda.java new file mode 100644 index 000000000..21ba05689 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiffSingleLambda.java @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff; + +/** + * A SameDiff lambda with only one output, used in while loop creation (the condition). 
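+ *
+ * <p>Hypothetical sketch of a loop condition (illustrative only); the returned
+ * variable must be of BOOL data type, as checked by whileLoop:
+ * <pre>{@code
+ * SameDiffSingleLambda cond = new SameDiffSingleLambda() {
+ *     @Override
+ *     public SDVariable define(SameDiff sd, SDVariable[] inputs) {
+ *         return inputs[0].lt(10.0);  // loop while the first loop variable is < 10
+ *     }
+ * };
+ * }</pre>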
+ */ +public interface SameDiffSingleLambda { + SDVariable define(SameDiff sameDiff, SDVariable[] inputs); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 0fb5dc360..b23dd576b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -16,12 +16,25 @@ package org.nd4j.autodiff.samediff.ops; +import com.google.common.collect.Sets; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; import lombok.NonNull; import org.nd4j.autodiff.functions.DifferentialFunctionFactory; +import org.nd4j.autodiff.samediff.ArgumentInterceptor; +import org.nd4j.autodiff.samediff.NameScope; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition; +import org.nd4j.autodiff.samediff.SameDiffLambda; +import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda; +import org.nd4j.autodiff.samediff.SameDiffSingleLambda; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; import org.nd4j.linalg.api.ops.impl.shape.OneHot; import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.indexing.conditions.Condition; @@ -3142,4 +3155,304 @@ public abstract class SDBaseOps { SDVariable ret = f().zerosLike(name, input); return updateVariableNameAndReference(ret, name); } + + + /** + * See {@link #any(String, SDVariable, int...)} + */ + public SDVariable any(SDVariable x, int... dimensions){ + return any(null, x, dimensions); + } + //TODO check any w/ no dimensions + + /** + * Boolean or array reduction operation, optionally along specified dimensions + * + * @param name Name of the output variable + * @param x Input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed + * @return Output variable: reduced array of rank (input rank - num dimensions) + */ + public SDVariable any(String name, SDVariable x, int... dimensions){ + validateBool("any", x); + SDVariable ret = f().any(x, dimensions); + return updateVariableNameAndReference(ret, name); + } + + + /** + * See {@link #all(String, SDVariable, int...)} + */ + public SDVariable all(SDVariable x, int... dimensions){ + return all(null, x, dimensions); + } + + + /** + * Boolean and array reduction operation, optionally along specified dimensions + * + * @param name Name of the output variable + * @param x Input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed + * @return Output variable: reduced array of rank (input rank - num dimensions) + */ + public SDVariable all(String name, SDVariable x, int... 
dimensions){ + validateBool("all", x); + SDVariable ret = f().all(x, dimensions); + return updateVariableNameAndReference(ret, name); + } + + /** + * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} + */ + public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + return whileLoop(null, null, loopVars, cond, body); + } + + /** + * See {@link #whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)} + */ + public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + return whileLoop(null, loopName, loopVars, cond, body); + } + + /** + * Constructs a While loop using the TensorFlow-style control flow operations (Switch, Merge, Enter, Exit, and NextIteration) + * + * Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false + * + * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used for further iterations. + * + * See Tensorflow Control Flow Implementation + * + * @param outputNames Names to give the output variables. If null, doesn't rename + * @param loopName The name of the loop block and frame (must be unique). If null, uses "while" + * @param loopVars Loop variables' inputs + * @param cond A lambda evaluating to the loop condition + * @param body A lambda doing the loop operation and returning the new loop variable values + * @return The values of the loop variables once condition is false + */ + public SDVariable[] whileLoop(String[] outputNames, final String loopName, @NonNull SDVariable[] loopVars, + @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body){ + + final String frameName = sd().newBlockName(loopName == null ? 
"while" : loopName); + + NameScope loopScope = sd().withNameScope(frameName); + + //SDVariable counter = SD.scalar(SD.generateNewVarName("counter", 0), 0); + + SDVariable[] entered = new SDVariable[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + entered[i] = f().enter(loopVars[i], frameName); + } + + //counter = SD.f().enter(counter, frameName); + + SDVariable[] merged = new SDVariable[loopVars.length]; + Merge[] mergeOps = new Merge[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + // the second arg will later be replaced with the output of NextIteration + // but that isn't available yet (and can't be, as it depends on this) + mergeOps[i] = new Merge(sd(), entered[i], entered[i]); + merged[i] = mergeOps[i].outputVariable(); + } + + //Merge counterMerge = new Merge(SD, counter, counter); + //counter = counterMerge.outputVariable(); + + NameScope condScope = sd().withNameScope("cond"); + SDVariable cond_result = cond.define(sd(), merged); + condScope.close(); + + + if (cond_result.dataType() != DataType.BOOL) + throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean."); + + + final Set alreadyEntered = Sets.newHashSet(); + SDVariable[] trueSwitches = new SDVariable[loopVars.length]; + SDVariable[] exits = new SDVariable[loopVars.length]; + for(int i = 0 ; i < loopVars.length ; i++){ + SDVariable[] s = f().switchOp(merged[i], cond_result); + trueSwitches[i] = s[1]; + alreadyEntered.add(s[1].getVarName()); + exits[i] = f().exit(s[0]); + } + + //SDVariable[] cs = SD.f().switchOp(counter, cond_result); + //SDVariable counterExit = SD.f().exit(cs[0]); + //counter = cs[1]; + + final Set declared = Sets.newHashSet(sd().variableMap().keySet()); + final Map done = new HashMap<>(); + + sd().addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + if(!declared.contains(argument.getVarName())) + return argument; + + if(alreadyEntered.contains(argument.getVarName())) + return argument; + + if(done.containsKey(argument.getVarName())) + return done.get(argument.getVarName()); + + SDVariable e = f().enter(argument, frameName, true); + done.put(argument.getVarName(), e); + return e; + } + }); + + NameScope bodyScope = sd().withNameScope("body"); + SDVariable[] outs = body.define(sd(), trueSwitches); + bodyScope.close(); + sd().removeArgumentInterceptor(); + + //counter.add(1); + + for(int i = 0 ; i < loopVars.length ; i++){ + SDVariable n = f().nextIteration(outs[i]); + mergeOps[i].replaceArg(1,n); + } + + //counterMerge.replaceArg(1, counter); + + loopScope.close(); + return updateVariableNamesAndReferences(exits, outputNames); + } + + /** + * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} + */ + public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + return ifCond(null, null, cond, trueBody, falseBody); + } + + + /** + * See {@link #ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)} + */ + public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + return ifCond(null, ifName, cond, trueBody, falseBody); + } + + /** + * Constructs a If statement using the tensorflow style control 
flow operations (Switch and Merge) + * + * If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody + * + * Note that cond and body lambdas are only called once to construct the graph. The constructed graph is used to evaluate. + * + * See Tensorflow Control Flow Implementation + * + * @param outputName Name to give the output variable. If null, doesn't rename + * @param ifName The name of the if block. If null, uses "if" + * @param cond A lambda evaluating to the if condition + * @param trueBody A lambda to be executed if cond is true (the if block) + * @param falseBody A lambda to be executed if cond is false (the else block) + * @return The value of trueBody if cond is true, or falseBody if it isn't + */ + public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, + @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody){ + + ifName = sd().newBlockName(ifName == null ? "if" : ifName); + + NameScope ifScope = sd().withNameScope(ifName); + + NameScope condScope = sd().withNameScope("cond"); + final SDVariable pred = cond.define(sd()); + condScope.close(); + + if (pred.dataType() != DataType.BOOL) { + //cleanup partially added block + + for(SDVariable v : sd().getVariablesInScope(ifScope)) + sd().getVariables().remove(v.getVarName()); + + for(SameDiffOp op : sd().getOpsInScope(ifScope)) { + for(String in : op.getInputsToOp()){ + sd().removeArgFromFunction(in, op.getOp()); + } + sd().getOps().remove(op.getName()); + } + + + throw new IllegalStateException("Cannot use " + pred.getVarName() + + " as the condition of an If statement, the condition must be a boolean."); + } + + final Map<String, SDVariable[]> switches = new HashMap<>(); + + final Set<String> declared = Sets.newHashSet(sd().variableMap().keySet()); + + sd().addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + // if it's declared in the if, we don't care about it + if(!declared.contains(argument.getVarName())) + return argument; + + // if we've already added a switch, move on + if(switches.containsKey(argument.getVarName())) + return switches.get(argument.getVarName())[1]; + + SDVariable[] s = f().switchOp(argument, pred); + switches.put(argument.getVarName(), s); + return s[1]; + } + }); + NameScope trueScope = sd().withNameScope("trueBody"); + SDVariable trueOut = trueBody.define(sd()); + sd().removeArgumentInterceptor(); + + if(declared.contains(trueOut.getVarName())) { + SDVariable[] s = f().switchOp(trueOut, pred); + switches.put(trueOut.getVarName(), s); + trueOut = s[1]; + } + + trueScope.close(); + + final Set<String> declared2 = Sets.newHashSet(sd().variableMap().keySet()); + sd().addArgumentInterceptor(new ArgumentInterceptor() { + @Override + public SDVariable intercept(SDVariable argument) { + + // if it's declared in the if, we don't care about it + if(!declared2.contains(argument.getVarName())) + return argument; + + // if we've already added a switch, move on + if(switches.containsKey(argument.getVarName())) + return switches.get(argument.getVarName())[0]; + + SDVariable[] s = f().switchOp(argument, pred); + switches.put(argument.getVarName(), s); + return s[0]; + } + }); + NameScope falseScope = sd().withNameScope("falseBody"); + SDVariable falseOut = falseBody.define(sd()); + sd().removeArgumentInterceptor(); + + if(declared2.contains(falseOut.getVarName())) { + SDVariable[] s = f().switchOp(falseOut, pred); + switches.put(falseOut.getVarName(), s); + falseOut = 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
index e131df0bc..928bf3e6e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
@@ -411,6 +411,31 @@ public class SDNN extends SDOps {
         return updateVariableNameAndReference(ret, name);
     }
+
+    /**
+     * Log softmax activation
+     *
+     * @param x         Input variable
+     * @param dimension Dimension along which to apply log softmax
+     * @return Output variable
+     */
+    public SDVariable logSoftmax(SDVariable x, int dimension) {
+        return logSoftmax(null, x, dimension);
+    }
+
+    /**
+     * Log softmax activation
+     *
+     * @param name      Variable name
+     * @param x         Input variable
+     * @param dimension Dimension along which to apply log softmax
+     * @return Output variable
+     */
+    public SDVariable logSoftmax(String name, SDVariable x, int dimension) {
+        validateFloatingPoint("log softmax", x);
+        SDVariable ret = f().logSoftmax(x, dimension);
+        return updateVariableNameAndReference(ret, name);
+    }
+
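A usage sketch for the new dimension-aware overload, not part of this diff; the sd.nn() accessor and the shapes are assumptions, and the softmax overloads added in the next hunk behave the same way:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class LogSoftmaxSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 10));

        // normalize over the class dimension (1) of a [minibatch, nClasses] input
        SDVariable logP = sd.nn().logSoftmax("logP", in, 1);
        System.out.println(logP.eval());
    }
}
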
     /**
      * Element-wise rectified linear function with specified cutoff:
      * out[i] = in[i] if in[i] >= cutoff
@@ -591,6 +616,31 @@ public class SDNN extends SDOps {
         return updateVariableNameAndReference(result, name);
     }
+
+    /**
+     * Softmax activation
+     *
+     * @param x         Input variable
+     * @param dimension Dimension along which to apply softmax
+     * @return Output variable
+     */
+    public SDVariable softmax(SDVariable x, int dimension) {
+        return softmax(null, x, dimension);
+    }
+
+    /**
+     * Softmax activation
+     *
+     * @param name      Variable name
+     * @param x         Input variable
+     * @param dimension Dimension along which to apply softmax
+     * @return Output variable
+     */
+    public SDVariable softmax(String name, SDVariable x, int dimension) {
+        validateFloatingPoint("softmax", x);
+        SDVariable result = f().softmax(x, dimension);
+        return updateVariableNameAndReference(result, name);
+    }
+
     /**
      * @param x
      * @return
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java
index eabf4fc9f..743fb527a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java
@@ -17,36 +17,47 @@ package org.nd4j.autodiff.samediff.serde;

 import com.google.flatbuffers.FlatBufferBuilder;
+import java.nio.ByteOrder;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
 import lombok.NonNull;
 import lombok.val;
 import org.nd4j.autodiff.functions.DifferentialFunction;
-import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.VariableType;
 import org.nd4j.base.Preconditions;
-import org.nd4j.graph.*;
+import org.nd4j.graph.DataType;
+import org.nd4j.graph.FlatArray;
+import org.nd4j.graph.FlatNode;
+import org.nd4j.graph.FlatProperties;
+import org.nd4j.graph.IntPair;
+import org.nd4j.graph.OpType;
+import org.nd4j.graph.VarType;
 import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
-import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.*;
-
+import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
+import org.nd4j.linalg.api.ops.BaseReduceOp;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.Op.Type;
+import org.nd4j.linalg.api.ops.ScalarOp;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
-import java.nio.ByteOrder;
-import java.util.*;
-
 public class FlatBuffersMapper {

-    private FlatBuffersMapper(){ }
+    private FlatBuffersMapper() {
+    }

     /**
      * This method converts enums for DataType
-     *
-     * @param type
-     * @return
      */
     public static byte getDataTypeAsByte(@NonNull org.nd4j.linalg.api.buffer.DataType type) {
         switch (type) {
@@ -84,88 +95,87 @@ public class FlatBuffersMapper {

     /**
      * This method converts enums for DataType
-     *
-     * @param val
-     * @return
      */
     public static org.nd4j.linalg.api.buffer.DataType getDataTypeFromByte(byte val) {
-        if (val == DataType.FLOAT)
+        if (val == DataType.FLOAT) {
             return org.nd4j.linalg.api.buffer.DataType.FLOAT;
-        else if (val == DataType.DOUBLE)
+        } else if (val == DataType.DOUBLE) {
             return org.nd4j.linalg.api.buffer.DataType.DOUBLE;
-        else if (val == DataType.HALF)
-            return org.nd4j.linalg.api.buffer.DataType.HALF;
-        else if (val == DataType.INT32)
+        } else if (val == DataType.HALF) {
+            return org.nd4j.linalg.api.buffer.DataType.HALF;
+        } else if (val == DataType.INT32) {
             return org.nd4j.linalg.api.buffer.DataType.INT;
-        else if (val == DataType.INT64)
+        } else if (val == DataType.INT64) {
             return org.nd4j.linalg.api.buffer.DataType.LONG;
-        else if (val == DataType.INT8)
+        } else if (val == DataType.INT8) {
             return org.nd4j.linalg.api.buffer.DataType.BYTE;
-        else if (val == DataType.BOOL)
+        } else if (val == DataType.BOOL) {
             return org.nd4j.linalg.api.buffer.DataType.BOOL;
-        else if (val == DataType.UINT8)
+        } else if (val == DataType.UINT8) {
             return org.nd4j.linalg.api.buffer.DataType.UBYTE;
-        else if (val == DataType.INT16)
+        } else if (val == DataType.INT16) {
             return org.nd4j.linalg.api.buffer.DataType.SHORT;
-        else if (val == DataType.UTF8)
+        } else if (val == DataType.UTF8) {
             return org.nd4j.linalg.api.buffer.DataType.UTF8;
-        else if (val == DataType.UINT16)
+        } else if (val == DataType.UINT16) {
             return org.nd4j.linalg.api.buffer.DataType.UINT16;
-        else if (val == DataType.UINT32)
+        } else if (val == DataType.UINT32) {
             return org.nd4j.linalg.api.buffer.DataType.UINT32;
-        else if (val == DataType.UINT64)
+        } else if (val == DataType.UINT64) {
             return org.nd4j.linalg.api.buffer.DataType.UINT64;
-        else
+        } else {
             throw new RuntimeException("Unknown datatype: " + val);
+        }
     }

-
-
     /**
      * This method returns the operation ID for a given op name/type pair.
-     *
-     * @param name
-     * @param type
-     * @return
      */
     public static long getOpNum(String name, Op.Type type) {
         if (type == Op.Type.LOOP) {
             return 0;
         } else if (type == Op.Type.RETURN) {
             return 40;
-        } else if (type == Op.Type.IF) {
-            return 30;
         } else if (type == Op.Type.CONDITIONAL) {
             return 10;
-        } else if (type == Op.Type.MERGE) {
-            return 60L;
         } else if (type == Op.Type.LOOP_COND) {
             return 70L;
-        } else if (type == Op.Type.NEXT_ITERATION) {
-            return 80L;
-        } else if (type == Op.Type.EXIT) {
-            return 90L;
-        } else if (type == Op.Type.ENTER) {
-            return 100L;
+        } else if (type == Type.LOGIC) {
+            switch (name) {
+                case Enter.OP_NAME:
+                    return Enter.OP_NUM;
+                case Exit.OP_NAME:
+                    return Exit.OP_NUM;
+                case NextIteration.OP_NAME:
+                    return NextIteration.OP_NUM;
+                case Merge.OP_NAME:
+                    return Merge.OP_NUM;
+                case Switch.OP_NAME:
+                    return Switch.OP_NUM;
+                default:
+                    throw new IllegalStateException("Unknown LOGIC op with name: " + name);
+            }
         } else if (type == Op.Type.CUSTOM) {
             val name2 = Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase());
             if (name2 == null) {
                 val name3 = Nd4j.getExecutioner().getCustomOperations().get(name);
-                if (name3 == null)
+                if (name3 == null) {
                     return 0;
-                else
+                } else {
                     return name3.getHash();
-            } else
+                }
+            } else {
                 return name2.getHash();
+            }
             //return Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase()).getHash();
         } else {
             try {
-                DifferentialFunction op = DifferentialFunctionClassHolder.getInstance().getInstance(name);
-                return op.opNum();
+                DifferentialFunction op = DifferentialFunctionClassHolder.getInstance().getInstance(name);
+                return op.opNum();
             } catch (Exception e) {
-                throw new RuntimeException("Could not find op number for operation: [" + name + "]",e);
+                throw new RuntimeException("Could not find op number for operation: [" + name + "]", e);
             }
         }
     }
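A sketch of the resulting round trip, not part of this diff; the numeric values come from the OP_NUM constants added to the compat ops later in this changeset (Enter=100, Exit=90, NextIteration=80, Merge=60, Switch=30):

import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;

public class LogicOpNumSketch {
    public static void main(String[] args) {
        // LOGIC ops now serialize by name under their legacy numeric ids, owned by the ops themselves
        long enterNum  = FlatBuffersMapper.getOpNum(Enter.OP_NAME, Op.Type.LOGIC);   // 100
        long switchNum = FlatBuffersMapper.getOpNum(Switch.OP_NAME, Op.Type.LOGIC);  // 30
        // fromFlatNode(...) resolves the class back via
        // DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName)
        System.out.println(enterNum + " " + switchNum);
    }
}

@@ 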
-212,7 +222,7 @@ public class FlatBuffersMapper { case OpType.RANDOM: return Op.Type.RANDOM; case OpType.LOGIC: - return Op.Type.META; + return Type.LOGIC; case OpType.CUSTOM: return Op.Type.CUSTOM; case OpType.PAIRWISE: @@ -269,15 +279,11 @@ public class FlatBuffersMapper { return OpType.INDEX_REDUCE; case RANDOM: return OpType.RANDOM; - case MERGE: case CONDITIONAL: case LOOP: case RETURN: - case ENTER: - case EXIT: - case NEXT_ITERATION: case LOOP_COND: - case IF: + case LOGIC: return OpType.LOGIC; case CUSTOM: return OpType.CUSTOM; @@ -295,88 +301,87 @@ public class FlatBuffersMapper { /** * This method just converts enums - * - * @param val - * @return */ public static ByteOrder getOrderFromByte(byte val) { - if (val == org.nd4j.graph.ByteOrder.LE) + if (val == org.nd4j.graph.ByteOrder.LE) { return ByteOrder.LITTLE_ENDIAN; - else + } else { return ByteOrder.BIG_ENDIAN; + } } /** * This method returns current byte order for this JVM as libnd4j enum - * - * @return */ public static byte getOrderAsByte() { - if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) + if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { return org.nd4j.graph.ByteOrder.BE; - else + } else { return org.nd4j.graph.ByteOrder.LE; + } } - public static DifferentialFunction fromFlatNode(FlatNode fn){ + public static DifferentialFunction fromFlatNode(FlatNode fn) { int id = fn.id(); //ID of the node String name = fn.name(); //Name of the node, NOT the name of the op Op.Type opType = FlatBuffersMapper.getTypeFromByte(fn.opType()); long opNum = fn.opNum(); //Op num: hash for custom, number for legacy int[] input = new int[fn.inputLength()]; - for( int i=0; i props = FlatBuffersMapper.mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties)); + Map props = FlatBuffersMapper + .mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties)); - - if(opType == Op.Type.CUSTOM) { + if (opType == Op.Type.CUSTOM || opType == Type.LOGIC) { String opName = fn.opName(); + + DifferentialFunction op; Class c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName); Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum); - DifferentialFunction op; try { op = (DifferentialFunction) c.newInstance(); } catch (IllegalAccessException | InstantiationException e) { throw new RuntimeException("Error creating differential function instance of type " + c); } + op.setOwnName(name); //Set input SDVariables: @@ -390,7 +395,7 @@ public class FlatBuffersMapper { op.setPropertiesForFunction(props); return op; } else { - Class c = LegacyOpMapper.getLegacyOpClassForId(opType, (int)opNum); + Class c = LegacyOpMapper.getLegacyOpClassForId(opType, (int) opNum); Op op; try { op = (Op) c.newInstance(); @@ -398,7 +403,7 @@ public class FlatBuffersMapper { throw new RuntimeException("Error creating differential function (Op) instance of type " + c); } - if(extraParams.length > 0) { + if (extraParams.length > 0) { //Assume that extraParams length 0 means extraArgs was originally null, NOT originally length 0 Object[] extraParamsObj = new Object[extraParams.length]; for (int i = 0; i < extraParams.length; i++) { @@ -406,16 +411,18 @@ public class FlatBuffersMapper { } op.setExtraArgs(extraParamsObj); } - if(opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL){ - ScalarOp sOp = (ScalarOp)op; + if (opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL) { + ScalarOp sOp = (ScalarOp) op; sOp.setScalar(scalar); - } else if(opType == Op.Type.REDUCE_FLOAT || opType == 
Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS || opType == Op.Type.VARIANCE - || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG || opType == Op.Type.REDUCE_SAME) { + } else if (opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS + || opType == Op.Type.VARIANCE + || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG + || opType == Op.Type.REDUCE_SAME) { val ba = (BaseReduceOp) op; //Reduce3 ops are also all BaseAccumulations ba.setDimensions(dimensions); ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions)); - } else if(opType == Op.Type.INDEXREDUCE){ - BaseIndexAccumulation bia = (BaseIndexAccumulation)op; + } else if (opType == Op.Type.INDEXREDUCE) { + BaseIndexAccumulation bia = (BaseIndexAccumulation) op; bia.setDimensions(dimensions); bia.setDimensionz(Shape.ndArrayDimFromInt(dimensions)); } @@ -428,8 +435,8 @@ public class FlatBuffersMapper { TRANSFORM_SAME - Abs, Ceil, etc */ - ((DifferentialFunction)op).setPropertiesForFunction(props); - return (DifferentialFunction)op; + ((DifferentialFunction) op).setPropertiesForFunction(props); + return (DifferentialFunction) op; } } @@ -438,11 +445,11 @@ public class FlatBuffersMapper { private static final long[] EMPTY_LONG = new long[0]; private static final double[] EMPTY_DOUBLE = new double[0]; - public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map fnProps){ + public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map fnProps) { int[] outIdxs = new int[fnProps.size()]; int count = 0; - for(Map.Entry e : fnProps.entrySet()){ + for (Map.Entry e : fnProps.entrySet()) { //Possible types here: primitives (as Number objects), primitive arrays, Strings, String arrays, multi-dimensional string/primitives Object v = e.getValue(); int iname = fbb.createString(e.getKey()); @@ -455,13 +462,11 @@ public class FlatBuffersMapper { int[] sIdx = null; int[] shape = null; - - - if(v == null) { + if (v == null) { //No op - } else if(v instanceof Boolean){ - b = new boolean[]{(Boolean)v}; - } else if(v instanceof Number) { + } else if (v instanceof Boolean) { + b = new boolean[]{(Boolean) v}; + } else if (v instanceof Number) { if (v instanceof Double) { d = new double[]{(Double) v}; } else if (v instanceof Integer) { @@ -469,39 +474,41 @@ public class FlatBuffersMapper { } else if (v instanceof Long) { l = new long[]{(Long) v}; } else { - throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass()); + throw new UnsupportedOperationException( + "Unable to map property \"" + e.getKey() + "\" of type " + v.getClass()); } - } else if(v instanceof String) { + } else if (v instanceof String) { String str = (String) v; int strOffset = fbb.createString(str); sIdx = new int[]{strOffset}; - } else if(v instanceof org.nd4j.linalg.api.buffer.DataType ) { + } else if (v instanceof org.nd4j.linalg.api.buffer.DataType) { String str = v.toString(); int strOffset = fbb.createString(str); sIdx = new int[]{strOffset}; - } else if(v instanceof Enum){ + } else if (v instanceof Enum) { String str = v.toString(); int strOffset = fbb.createString(str); sIdx = new int[]{strOffset}; - } else if(v instanceof INDArray){ - INDArray arr = (INDArray)v; + } else if (v instanceof INDArray) { + INDArray arr = (INDArray) v; aIdx = new int[]{arr.toFlatArray(fbb)}; - } else if(v.getClass().isArray()){ - if(v.getClass().getComponentType().isPrimitive()){ - if(v instanceof boolean[]) { - b = (boolean[])v; + } else if 
(v.getClass().isArray()) { + if (v.getClass().getComponentType().isPrimitive()) { + if (v instanceof boolean[]) { + b = (boolean[]) v; shape = new int[]{b.length}; - } else if(v instanceof double[]){ - d = (double[])v; + } else if (v instanceof double[]) { + d = (double[]) v; shape = new int[]{d.length}; - } else if(v instanceof int[]){ - i = (int[])v; + } else if (v instanceof int[]) { + i = (int[]) v; shape = new int[]{i.length}; - } else if(v instanceof long[]){ - l = (long[])v; + } else if (v instanceof long[]) { + l = (long[]) v; shape = new int[]{l.length}; } else { - throw new UnsupportedOperationException("Unable to map property \"" + e.getKey() + "\" of type " + v.getClass()); + throw new UnsupportedOperationException( + "Unable to map property \"" + e.getKey() + "\" of type " + v.getClass()); } } else if (v instanceof String[]) { //String[] @@ -511,33 +518,35 @@ public class FlatBuffersMapper { sIdx[j] = fbb.createString(strArr[j]); } shape = new int[]{strArr.length}; - } else if (v instanceof INDArray[]){ - INDArray[] arrArr = (INDArray[])v; + } else if (v instanceof INDArray[]) { + INDArray[] arrArr = (INDArray[]) v; aIdx = new int[arrArr.length]; - for( int j=0; j mapFlatPropertiesToFunctionProperties(Iterable list){ - Map out = new HashMap<>(); - for(FlatProperties p : list){ + public static Map mapFlatPropertiesToFunctionProperties(Iterable list) { + Map out = new HashMap<>(); + for (FlatProperties p : list) { String name = p.name(); //Work out type: - if(p.shapeLength() > 0){ + if (p.shapeLength() > 0) { //Array type int[] shape = new int[p.shapeLength()]; - for( int i=0; i 0){ + if (p.iLength() > 0) { int[] iArr = new int[p.iLength()]; - for( int i=0; i 0){ + } else if (p.dLength() > 0) { double[] dArr = new double[p.dLength()]; - for( int i=0; i 0) { + } else if (p.lLength() > 0) { long[] lArr = new long[p.lLength()]; for (int i = 0; i < lArr.length; i++) { lArr[i] = p.l(i); } - if(shape.length == 0 || shape.length == 1) { + if (shape.length == 0 || shape.length == 1) { out.put(name, lArr); - } else if(shape.length == 2){ + } else if (shape.length == 2) { out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1])); - } else if(shape.length == 3){ + } else if (shape.length == 3) { out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1], shape[2])); } - } else if(p.bLength() > 0){ + } else if (p.bLength() > 0) { boolean[] bArr = new boolean[p.bLength()]; - for( int i=0; i 0){ + } else if (p.sLength() > 0) { String[] sArr = new String[p.sLength()]; - for( int i=0; i 0){ + } else if (p.aLength() > 0) { INDArray[] iArr = new INDArray[p.aLength()]; - for( int i=0; i 0) { + if (p.bLength() > 0) { out.put(name, p.b(0)); - } else if(p.iLength() > 0){ + } else if (p.iLength() > 0) { out.put(name, p.i(0)); - } else if(p.lLength() > 0){ + } else if (p.lLength() > 0) { out.put(name, p.l(0)); - } else if(p.dLength() > 0){ + } else if (p.dLength() > 0) { out.put(name, p.d(0)); - } else if(p.sLength() > 0){ + } else if (p.sLength() > 0) { out.put(name, p.s(0)); - } else if(p.aLength() > 0){ + } else if (p.aLength() > 0) { FlatArray fa = p.a(0); out.put(name, Nd4j.createFromFlatArray(fa)); } else { @@ -673,8 +683,8 @@ public class FlatBuffersMapper { return out; } - public static byte toVarType(VariableType variableType){ - switch (variableType){ + public static byte toVarType(VariableType variableType) { + switch (variableType) { case VARIABLE: return VarType.VARIABLE; case CONSTANT: @@ -688,8 +698,8 @@ public class FlatBuffersMapper { } } - public static VariableType 
fromVarType(byte varType){ - switch (varType){ + public static VariableType fromVarType(byte varType) { + switch (varType) { case VarType.VARIABLE: return VariableType.VARIABLE; case VarType.CONSTANT: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index 1f550a8aa..3e1772bd9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -126,12 +126,7 @@ public class LegacyOpMapper { case CONDITIONAL: case LOOP: case LOOP_COND: - case IF: case RETURN: - case ENTER: - case EXIT: - case NEXT_ITERATION: - case MERGE: default: throw new UnsupportedOperationException("Unable to map op " + opNum + " of type " + opType); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 42c0cecfc..c7464448a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -25,6 +25,11 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser; import org.nd4j.imports.descriptors.onnx.OpDescriptor; import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser; import org.nd4j.linalg.api.ops.*; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; import org.nd4j.linalg.api.ops.impl.layers.convolution.*; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; @@ -331,13 +336,27 @@ public class DifferentialFunctionClassHolder { } public Class customOpClassForHashAndName(long customOpHash, String name){ - if(customOpHashToClasses.containsKey(customOpHash)){ - return customOpHashToClasses.get(customOpHash).get(name); - } else if(customOpHashToClass.containsKey(customOpHash)){ - return customOpHashToClass.get(customOpHash); - } else { - throw new IllegalStateException("No op known for hash: " + customOpHash); + switch (name) { + case Enter.OP_NAME: + return Enter.class; + case Exit.OP_NAME: + return Exit.class; + case NextIteration.OP_NAME: + return NextIteration.class; + case Merge.OP_NAME: + return Merge.class; + case Switch.OP_NAME: + return Switch.class; + default: + if(customOpHashToClasses.containsKey(customOpHash)){ + return customOpHashToClasses.get(customOpHash).get(name); + } else if(customOpHashToClass.containsKey(customOpHash)){ + return customOpHashToClass.get(customOpHash); + } else { + throw new IllegalStateException("No op known for hash: " + customOpHash); + } } + } public static DifferentialFunctionClassHolder getInstance() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java index 14549b049..3e5644439 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/Op.java @@ -69,14 +69,10 @@ public interface Op { CONDITIONAL, LOOP, LOOP_COND, - IF, RETURN, - ENTER, - EXIT, - NEXT_ITERATION, RANDOM, - MERGE, SUMMARYSTATS, + LOGIC } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java index 54c5f2fe0..85a94eb13 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java @@ -17,11 +17,13 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; import lombok.Data; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.Op.Type; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -32,13 +34,38 @@ import java.util.List; import java.util.Map; @Data +@NoArgsConstructor public class Enter extends BaseCompatOp { protected boolean isConstant; + public Enter(SameDiff sameDiff, SDVariable[] inputs){ + super(sameDiff, inputs); + } + + public Enter(SameDiff sameDiff, String frameName, SDVariable input){ + super(sameDiff, new SDVariable[]{input}); + this.frameName = frameName; + isConstant = input.isConstant(); + } + + public Enter(SameDiff sameDiff, String frameName, SDVariable input, boolean isConstant){ + super(sameDiff, new SDVariable[]{input}); + this.frameName = frameName; + this.isConstant = isConstant; + } + + /** + * WARNING: do not change without changing serialization methods + * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)} + * and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)} + */ + public static final String OP_NAME = "enter"; + public static final int OP_NUM = 100; + @Override public String opName() { - return "enter"; + return OP_NAME; } @Override @@ -62,7 +89,7 @@ public class Enter extends BaseCompatOp { @Override public Op.Type opType() { - return Op.Type.ENTER; + return Type.LOGIC; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java index 9bed4af8e..f9e358f3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; +import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -24,6 +25,7 @@ import 
org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.Op.Type; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -34,10 +36,24 @@ import java.util.Collections; import java.util.List; import java.util.Map; +@NoArgsConstructor public class Exit extends BaseCompatOp { + + public Exit(SameDiff sameDiff, SDVariable x) { + super(sameDiff, new SDVariable[]{x}); + } + + /** + * WARNING: do not change without changing serialization methods + * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)} + * and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)} + */ + public static final String OP_NAME = "exit"; + public static final int OP_NUM = 90; + @Override public String opName() { - return "exit"; + return OP_NAME; } @Override @@ -61,7 +77,7 @@ public class Exit extends BaseCompatOp { @Override public Op.Type opType() { - return Op.Type.EXIT; + return Type.LOGIC; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java index 7600def65..386f4a075 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.Op.Type; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,9 +42,21 @@ public class Merge extends BaseCompatOp { } + /** + * WARNING: do not change without changing serialization methods + * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)} + * and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)} + */ + public static final String OP_NAME = "merge"; + public static final int OP_NUM = 60; + + public Merge(SameDiff sd, SDVariable a, SDVariable b){ + this(sd, new SDVariable[]{a, b}); + } + @Override public String opName() { - return "merge"; + return OP_NAME; } @Override @@ -72,7 +85,7 @@ public class Merge extends BaseCompatOp { @Override public Op.Type opType() { - return Op.Type.MERGE; + return Type.LOGIC; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java index 28ac283be..fabd0479b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java @@ -16,11 +16,13 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; +import lombok.NoArgsConstructor; import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.Op.Type; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -31,10 +33,24 @@ import java.util.Collections; import java.util.List; import java.util.Map; +@NoArgsConstructor public class NextIteration extends BaseCompatOp { + + public NextIteration(SameDiff sameDiff, SDVariable x) { + super(sameDiff, new SDVariable[]{x}); + } + + /** + * WARNING: do not change without changing serialization methods + * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)} + * and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)} + */ + public static final String OP_NAME = "next_iteration"; + public static final int OP_NUM = 80; + @Override public String opName() { - return "next_iteration"; + return OP_NAME; } @Override @@ -58,7 +74,7 @@ public class NextIteration extends BaseCompatOp { @Override public Op.Type opType() { - return Op.Type.NEXT_ITERATION; + return Type.LOGIC; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java index 7a79f1911..77145a625 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java @@ -16,12 +16,15 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; +import com.google.common.collect.Lists; +import lombok.Getter; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.Op.Type; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -37,15 +40,27 @@ import java.util.Map; */ public class Switch extends BaseCompatOp { + @Getter + private SDVariable predicate; + public Switch(SameDiff sameDiff, SDVariable input, SDVariable predicate){ super(sameDiff, new SDVariable[]{input, predicate}); + this.predicate = predicate; } public Switch(){ } + /** + * WARNING: do not change without changing serialization methods + * See {@link org.nd4j.autodiff.samediff.serde.FlatBuffersMapper#getOpNum(String, Type)} + * and {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder#customOpClassForHashAndName(long, String)} + */ + public static final String OP_NAME = "switch"; + public static final int OP_NUM = 30; + @Override public String opName() { - return "switch"; + return OP_NAME; } @Override @@ -72,7 +87,7 @@ public class Switch extends BaseCompatOp { @Override public Op.Type opType() { - return Op.Type.IF; + return Type.LOGIC; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index 
fb5dd2af2..b13870606 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java
@@ -39,6 +39,9 @@ import java.util.List;
  */
 public class LogSoftMax extends DynamicCustomOp {

+
+    private Integer dimension = null;
+
     public LogSoftMax(SameDiff sameDiff, SDVariable i_v) {
         super(sameDiff, i_v);
     }
@@ -54,6 +57,12 @@ public class LogSoftMax extends DynamicCustomOp {
         this(x, x);
     }

+    public LogSoftMax(SameDiff sameDiff, SDVariable i_v, int dimension) {
+        this(sameDiff, i_v);
+        this.dimension = dimension;
+        addIArgument(dimension);
+    }
+

     @Override
     public String opName() {
@@ -66,8 +75,13 @@

     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
-        return Collections.singletonList(ret);
+        if(dimension == null) {
+            SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
+            return Collections.singletonList(ret);
+        } else {
+            SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0), dimension);
+            return Collections.singletonList(ret);
+        }
     }

     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LogSoftMaxDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LogSoftMaxDerivative.java
index fd68f99f3..6be0367fe 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LogSoftMaxDerivative.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LogSoftMaxDerivative.java
@@ -43,6 +43,11 @@ public class LogSoftMaxDerivative extends DynamicCustomOp {
         super(null, new INDArray[]{in, gradO}, new INDArray[]{out});
     }

+    public LogSoftMaxDerivative(SameDiff sameDiff, SDVariable arg, SDVariable wrt, int dimension) {
+        this(sameDiff, arg, wrt);
+        this.addIArgument(dimension);
+    }
+
     /**
      * The opName of this operation
      *
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java
index c5ca3aa13..3b1f57d33 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java
@@ -129,4 +129,39 @@
             }
         }
     }
+
+    @Test
+    public void testNoNesting(){
+        SameDiff SD = SameDiff.create();
+
+        SDVariable a = SD.constant(4);
+
+        NameScope scope = SD.withNameScope("test");
+
+        SDVariable out = SD.argmax(a);
+
+        out.add(45);
+
+        scope.close();
+
+        assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1"));
+    }
+
+    @Test
+    public void testNoNesting2(){
+        SameDiff SD = SameDiff.create();
+
+        SDVariable a = SD.constant(4);
+        SDVariable b = SD.constant(5).lt(4);
+
+        NameScope scope = SD.withNameScope("test");
+
+        SDVariable out = SD.f().switchOp(a, b)[0];
+
+        out.add(45);
+
+        scope.close();
+
+        assertTrue("Var with name test/switch:1 exists", SD.variableMap().containsKey("test/switch:1"));
+    }
 }
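A gradient-path sketch for the dimension-aware LogSoftMax, not part of this diff; the sd.nn() accessor and the shapes are assumptions, and doDiff above is what dispatches to the new LogSoftMaxDerivative constructor:

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class LogSoftmaxGradSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 5));

        // dimension-aware path: doDiff dispatches to logSoftmaxDerivative(arg, grad, 1)
        SDVariable lsm = sd.nn().logSoftmax("lsm", in, 1);
        lsm.sum().markAsLoss();
        sd.execBackwards(Collections.emptyMap());
        System.out.println(sd.grad("in").getArr());
    }
}

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java 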
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 3408053d2..167e490a8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -16,12 +16,30 @@ package org.nd4j.autodiff.samediff; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeNotNull; +import static org.nd4j.linalg.indexing.NDArrayIndex.all; + import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Before; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.nd4j.OpValidationSuite; @@ -43,7 +61,11 @@ import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax; import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin; -import org.nd4j.linalg.api.ops.impl.transforms.custom.*; +import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing; +import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor; +import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing; +import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -53,9 +75,7 @@ import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.primitives.Pair; import org.nd4j.nativeblas.NativeOpsHolder; @@ -63,29 +83,20 @@ import org.nd4j.weightinit.impl.OneInitScheme; import org.nd4j.weightinit.impl.UniformInitScheme; import org.nd4j.weightinit.impl.ZeroInitScheme; -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.lang.reflect.Field; -import java.util.*; - -import static org.junit.Assert.*; -import static org.junit.Assume.assumeNotNull; -import static org.nd4j.linalg.indexing.NDArrayIndex.all; - /** * Created by agibsonccc on 4/11/17. 
*/ @Slf4j public class SameDiffTests extends BaseNd4jTest { + private DataType initialType; - public SameDiffTests(Nd4jBackend b){ + public SameDiffTests(Nd4jBackend b) { super(b); } @Override - public char ordering(){ + public char ordering() { return 'c'; } @@ -317,7 +328,6 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff first = SameDiff.create(); SameDiff second = SameDiff.create(); - SDVariable firstVar = first.var("one", new long[]{2, 2}); SDVariable secondVar = second.var(firstVar); assertTrue(firstVar.getArr() == secondVar.getArr()); @@ -330,7 +340,6 @@ public class SameDiffTests extends BaseNd4jTest { SameDiff first = SameDiff.create(); SameDiff second = SameDiff.create(); - SDVariable firstVar = first.var("one", new long[]{2, 2}); SDVariable secondVar = second.var(firstVar); assumeNotNull(firstVar.getArr()); @@ -418,7 +427,6 @@ public class SameDiffTests extends BaseNd4jTest { } }, xAndY); - INDArray assertionForDiv = Nd4j.valueArrayOf(4, 4.0); INDArray assertionForRDiv = Nd4j.valueArrayOf(4, 0.25); assertEquals(assertionForDiv, sameDiff.getFunction("div").execAndEndResult()); @@ -463,7 +471,8 @@ public class SameDiffTests extends BaseNd4jTest { }, inputs); INDArray assertion = sumInput.sum(1); - INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum")).get("sum"); + INDArray out = sameDiff.getFunction("sum").exec(Collections.emptyMap(), Collections.singletonList("sum")) + .get("sum"); assertEquals(assertion, out); } @@ -563,7 +572,6 @@ public class SameDiffTests extends BaseNd4jTest { } }, inputVars); - //1 input plus 2 outputs assertEquals(3, functionDef.variables().size()); @@ -573,7 +581,8 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testIfStatementTrueBodyBackwards() { - OpValidationSuite.ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations + OpValidationSuite + .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations SameDiff sameDiff = SameDiff.create(); SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { @Override @@ -584,7 +593,6 @@ public class SameDiffTests extends BaseNd4jTest { } }; - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { @@ -607,7 +615,6 @@ public class SameDiffTests extends BaseNd4jTest { }; - sameDiff.ifStatement(new DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs); sameDiff.execBackwards(Collections.emptyMap()); SameDiff grad = sameDiff.getFunction("grad"); @@ -625,7 +632,8 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testIfStatementTrueBody() { - OpValidationSuite.ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations + OpValidationSuite + .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations SameDiff sameDiff = SameDiff.create(); SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { @@ -637,7 +645,6 @@ public class SameDiffTests extends BaseNd4jTest { } }; - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { @@ -660,7 +667,6 @@ public class SameDiffTests extends BaseNd4jTest { }; - sameDiff.ifStatement(new 
DefaultSameDiffConditional(), conditionBody, trueBody, falseBody, firstInputs); sameDiff.exec(Collections.emptyMap()); } @@ -668,7 +674,8 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testIfStatementFalseBody() { - OpValidationSuite.ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations + OpValidationSuite + .ignoreFailing(); //2019/01/14 AB: Disabled pending overhaul of SameDiff-defined conditional operations SameDiff sameDiff = SameDiff.create(); SameDiffFunctionDefinition conditionBody = new SameDiffFunctionDefinition() { @@ -680,7 +687,6 @@ public class SameDiffTests extends BaseNd4jTest { } }; - SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { @@ -697,7 +703,6 @@ public class SameDiffTests extends BaseNd4jTest { } }; - //false body trigger SDVariable[] secondInputs = new SDVariable[]{ sameDiff.setupFunction(sameDiff.var("two", new long[]{1, 1})) @@ -790,7 +795,6 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable weights = sd.var("W", new long[]{nIn, nOut}); SDVariable bias = sd.var("b", new long[]{1, nOut}); - SDVariable mmul = sd.mmul("mmul", input, weights); SDVariable z = mmul.add("z", bias); SDVariable out = sd.math().tanh(z); @@ -888,7 +892,6 @@ public class SameDiffTests extends BaseNd4jTest { val f = m.add(2.0); val s = in2.add(5.0); - val arr = sd.execSingle(null, s.getVarName()); log.info("Result M: {}", m.getArr()); log.info("Result F: {}", f.getArr()); @@ -939,7 +942,8 @@ public class SameDiffTests extends BaseNd4jTest { val vector = Nd4j.linspace(1, 4, 4).reshape(4, 1); val input1 = sd.var("input", matrix); val input2 = sd.var("input2", vector); - val output = sd.mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build()); + val output = sd + .mmul("output", input1, input2, MMulTranspose.builder().transposeA(true).transposeB(false).build()); output.eval(); assertArrayEquals(new long[]{3, 1}, output.getShape()); } @@ -1026,12 +1030,11 @@ public class SameDiffTests extends BaseNd4jTest { } }, inputs); - SameDiff logisticGraph = sameDiffOuter.getFunction("oneminuspredictions"); Map inputsSubset = new HashMap<>(); inputsSubset.put("y", inputs.get("y")); INDArray output = logisticGraph.exec(inputsSubset, Collections.singletonList("rsub")).get("rsub"); - INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4,1}); + INDArray assertion = Nd4j.create(new double[]{0, 0, 1, 0}, new int[]{4, 1}); assertEquals(assertion, output); } @@ -1076,7 +1079,6 @@ public class SameDiffTests extends BaseNd4jTest { } }, inputs); - SameDiff logisticPrediction = sameDiffOuter.getFunction("logisticPredictions"); List logisticOpNameAssertions = Arrays.asList("mmul", "sigmoid"); @@ -1146,7 +1148,8 @@ public class SameDiffTests extends BaseNd4jTest { Activation.SOFTPLUS, Activation.SOFTSIGN, Activation.HARDTANH, - Activation.CUBE, //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426 + Activation.CUBE, + //WRONG output - see issue https://github.com/deeplearning4j/nd4j/issues/2426 Activation.RELU, //JVM crash Activation.LEAKYRELU //JVM crash }; @@ -1289,8 +1292,9 @@ public class SameDiffTests extends BaseNd4jTest { sd.exec(Collections.emptyMap(), sd.outputs()); - for (int i = 0; i < 4; i++) + for (int i = 0; i < 4; i++) { assertEquals(1, out.getArr().get(all(), NDArrayIndex.point(i), all(), all()).getInt(0)); + } } @@ 
-1327,7 +1331,6 @@ public class SameDiffTests extends BaseNd4jTest { INDArray means = Nd4j.create(new float[]{2, 4}, new long[]{1, 2}); INDArray vars = Nd4j.create(new float[]{6, 8}, new long[]{1, 2}); - SDVariable sdCounts = sd.var("counts", counts); SDVariable sdMeans = sd.var("means", means); SDVariable sdVars = sd.var("vars", vars); @@ -1363,7 +1366,6 @@ public class SameDiffTests extends BaseNd4jTest { int imgH = 28; int imgW = 28; - SameDiff sd = SameDiff.create(); INDArray depthWeightArr = Nd4j.create(kH, kW, nIn, depthWise); @@ -1720,7 +1722,6 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in1 = sd.var("in1", ia); SDVariable in2 = sd.var("in2", ib); - SDVariable t; INDArray expOut; switch (i) { @@ -1835,7 +1836,8 @@ public class SameDiffTests extends BaseNd4jTest { val origShape = new long[]{3, 4}; for (int i = 0; i < 3; i++) { - for (Pair p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) { + for (Pair p : NDArrayCreationUtil + .getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) { INDArray inArr = p.getFirst().muli(100); SameDiff sd = SameDiff.create(); @@ -1875,7 +1877,8 @@ public class SameDiffTests extends BaseNd4jTest { val shape = origShape.clone(); shape[i] = 1; - for (Pair p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) { + for (Pair p : NDArrayCreationUtil + .getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) { INDArray inArr = p.getFirst().muli(100); SameDiff sd = SameDiff.create(); @@ -1912,7 +1915,8 @@ public class SameDiffTests extends BaseNd4jTest { val origShape = new long[]{3, 4}; for (int i = 0; i < 3; i++) { - for (Pair p : NDArrayCreationUtil.getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) { + for (Pair p : NDArrayCreationUtil + .getAllTestMatricesWithShape(origShape[0], origShape[1], 12345, DataType.FLOAT)) { INDArray inArr = p.getFirst().muli(100); SameDiff sd = SameDiff.create(); @@ -1939,7 +1943,8 @@ public class SameDiffTests extends BaseNd4jTest { val shape = origShape.clone(); shape[i] = 1; - for (Pair p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) { + for (Pair p : NDArrayCreationUtil + .getAll3dTestArraysWithShape(12345, shape, DataType.FLOAT)) { INDArray inArr = p.getFirst().muli(100); SameDiff sd = SameDiff.create(); @@ -2214,7 +2219,6 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable in = sd.var("in", 1, 2); sd.associateArrayWithVariable(ia, in); - INDArray expFinite = Nd4j.create(new boolean[]{true, true}); SDVariable finite = sd.math().isFinite(in); @@ -2259,11 +2263,10 @@ public class SameDiffTests extends BaseNd4jTest { SDVariable result2 = x.get(SDIndex.point(4), SDIndex.all()); assertEquals(expOut2, result2.eval()); - INDArray expOut3 = arr.get(NDArrayIndex.interval(3, 8)).reshape(5,10); + INDArray expOut3 = arr.get(NDArrayIndex.interval(3, 8)).reshape(5, 10); SDVariable result3 = x.get(SDIndex.interval(3, 8)); assertEquals(expOut3, result3.eval()); - INDArray expOut4 = arr.get(NDArrayIndex.point(5), NDArrayIndex.interval(3, 8)).reshape(5); SDVariable result4 = x.get(SDIndex.point(5), SDIndex.interval(3, 8)); assertEquals(expOut4, result4.eval()); @@ -2295,7 +2298,6 @@ public class SameDiffTests extends BaseNd4jTest { INDArray s3a = s3.eval(); assertEquals(s3a, y3); - INDArray y4 = arr.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.interval(3, 5)); SDVariable s4 = x.get(SDIndex.point(2), SDIndex.all(), 
SDIndex.interval(3, 5)); INDArray s4a = s4.eval(); @@ -2409,7 +2411,6 @@ public class SameDiffTests extends BaseNd4jTest { }, new int[]{3, 2, 4}); - SDVariable x = sd.var(arr); SDVariable result = sd.permute(x, 1, 0, 2); assertEquals(expOut, result.eval()); @@ -2470,7 +2471,7 @@ public class SameDiffTests extends BaseNd4jTest { ExternalErrorsFunction fn = sd.f().externalErrors(out); sd.execAndEndResult(); - Map m = new HashMap<>(); + Map m = new HashMap<>(); m.put("out-grad", externalGrad); sd.execBackwards(m); @@ -2488,7 +2489,6 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(externalGrad.mul(0.5), gradVar); - //Test model serialization: } @@ -2620,7 +2620,7 @@ public class SameDiffTests extends BaseNd4jTest { b.setArray(bA); INDArray grad = Nd4j.linspace(1, 12, 12, DataType.FLOAT).reshape(3, 4); - Map phMap = new HashMap<>(); + Map phMap = new HashMap<>(); phMap.put(fn.getGradPlaceholderName(), grad); log.info("--------------- sd.execAndEndResult() ---------------"); @@ -2723,7 +2723,6 @@ public class SameDiffTests extends BaseNd4jTest { .build(); sd.setTrainingConfig(c); - sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr, null).toMultiDataSet()), 1); INDArray out = tanh.eval(); @@ -2757,7 +2756,7 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3); in.setArray(inArr); - INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3,4); + INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4); TrainingConfig c = TrainingConfig.builder() .updater(new Adam(0.1)) @@ -2767,7 +2766,6 @@ public class SameDiffTests extends BaseNd4jTest { .build(); sd.setTrainingConfig(c); - sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(new INDArray[]{inArr, inArr2}, null)), 1); INDArray out = tanh.eval(); @@ -2859,7 +2857,6 @@ public class SameDiffTests extends BaseNd4jTest { } final INDArray out = Nd4j.concat(2, output).norm2(); - SameDiff sd = SameDiff.create(); final SDVariable sdInput = sd.var("input", input); @@ -2905,7 +2902,6 @@ public class SameDiffTests extends BaseNd4jTest { } final INDArray out = Nd4j.concat(2, output).norm2(); - SameDiff sd = SameDiff.create(); final SDVariable sdInput = sd.var("input", input); @@ -2917,13 +2913,11 @@ public class SameDiffTests extends BaseNd4jTest { outputSlices[0] = x_0; outputSlices[0] = sd.expandDims("X_0-e", outputSlices[0], 2); - final val x_1 = inputSlices[1]; outputSlices[1] = x_1; outputSlices[1] = outputSlices[1].add(sd.squeeze("X_0-s", outputSlices[0], 2)); outputSlices[1] = sd.expandDims("X_1-e", outputSlices[1], 2); - SDVariable t = sd.concat(2, outputSlices); t.norm2("out"); String err = OpValidation.validate(new TestCase(sd) @@ -3036,7 +3030,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSameDiffBackprop1(){ + public void testSameDiffBackprop1() { SameDiff sd = SameDiff.create(); final SDVariable a = sd.var("a", Nd4j.rand(4, 4)); final SDVariable b = sd.var("b", Nd4j.rand(4, 4)); @@ -3050,7 +3044,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSameDiffNoGradForConstantAndPlaceholder(){ + public void testSameDiffNoGradForConstantAndPlaceholder() { SameDiff sd = SameDiff.create(); final SDVariable a = sd.var("a", Nd4j.rand(4, 4)); final SDVariable b = sd.constant("b", Nd4j.rand(4, 4)); @@ -3058,16 +3052,16 @@ public class SameDiffTests extends BaseNd4jTest { a.add(b.add(c)).sum().markAsLoss(); - sd.execBackwards(Collections.singletonMap("c", Nd4j.rand(4,4 ))); + sd.execBackwards(Collections.singletonMap("c", Nd4j.rand(4, 
4))); assertNotNull(sd.grad("a")); assertNull(sd.grad("b")); assertNull(sd.grad("c")); } @Test - public void testDuplicateNamePlaceholder(){ + public void testDuplicateNamePlaceholder() { - for( int i=0; i<2; i++ ) { + for (int i = 0; i < 2; i++) { SameDiff sd = SameDiff.create(); SDVariable x1 = i == 0 ? sd.placeHolder("a", DataType.FLOAT, 5, 3) : sd.var("a", DataType.FLOAT, 5, 3); SDVariable x2 = i == 0 ? sd.placeHolder("b", DataType.FLOAT, 5, 3) : sd.var("b", DataType.FLOAT, 5, 3); @@ -3119,7 +3113,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testSameDiffGetArrayScalar(){ + public void testSameDiffGetArrayScalar() { final INDArray array = Nd4j.rand(1, 1); final SameDiff sd = SameDiff.create(); final SDVariable a = sd.var("a", array.shape()); @@ -3128,11 +3122,11 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testVariableRenaming(){ + public void testVariableRenaming() { SameDiff sd = SameDiff.create(); - SDVariable v1 = sd.var("x", Nd4j.rand(DataType.FLOAT, 3,4)); - SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4,5)); + SDVariable v1 = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4)); + SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5)); SDVariable v3 = v1.mmul("oldName", v2); INDArray out = sd.execSingle(null, "oldName"); @@ -3150,11 +3144,11 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testVariableRenaming2(){ + public void testVariableRenaming2() { SameDiff sd = SameDiff.create(); - SDVariable v1 = sd.placeHolder("x", DataType.FLOAT,3,4); - SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4,5)); + SDVariable v1 = sd.placeHolder("x", DataType.FLOAT, 3, 4); + SDVariable v2 = sd.var("y", Nd4j.rand(DataType.FLOAT, 4, 5)); SDVariable v3 = v1.mmul("oldName", v2); SDVariable v4 = v3.std("out", false); @@ -3172,7 +3166,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testPlaceholderShapeValidation(){ + public void testPlaceholderShapeValidation() { SameDiff sd = SameDiff.create(); SDVariable ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4); SDVariable ph2 = sd.placeHolder("ph2", DataType.FLOAT, -1, 4); @@ -3183,33 +3177,36 @@ public class SameDiffTests extends BaseNd4jTest { INDArray wrongShape = Nd4j.create(DataType.FLOAT, 2, 3); INDArray wrongRank1 = Nd4j.create(DataType.FLOAT, 1); INDArray wrongRank2 = Nd4j.create(DataType.FLOAT, 3, 4, 5); - for(SDVariable v : new SDVariable[]{ph1, ph2, ph3, ph4}){ + for (SDVariable v : new SDVariable[]{ph1, ph2, ph3, ph4}) { v.setArray(correctShape); - if(v != ph4) { + if (v != ph4) { try { v.setArray(wrongShape); fail("Expected exception"); } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg.contains(Arrays.toString(v.placeholderShape()))); + assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]") && msg + .contains(Arrays.toString(v.placeholderShape()))); } } - try{ + try { v.setArray(wrongRank1); fail("Expected exception"); - } catch (Exception t){ + } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg.contains(Arrays.toString(v.placeholderShape()))); + assertTrue(msg, msg.contains("shape") && msg.contains("[1]") && msg + .contains(Arrays.toString(v.placeholderShape()))); } - try{ + try { v.setArray(wrongRank2); fail("Expected exception"); - } catch (Exception t){ + } catch (Exception t) { String msg = t.getMessage(); - assertTrue(msg, msg.contains("shape") && 
msg.contains("[3, 4, 5]") && msg.contains(Arrays.toString(v.placeholderShape()))); + assertTrue(msg, msg.contains("shape") && msg.contains("[3, 4, 5]") && msg + .contains(Arrays.toString(v.placeholderShape()))); } } @@ -3223,9 +3220,9 @@ public class SameDiffTests extends BaseNd4jTest { .markLabelsUnused() .updater(new Adam(1e-3)).build()); - try{ + try { sd.fit(mds); - } catch (Exception t){ + } catch (Exception t) { String msg = t.getMessage(); assertTrue(msg, msg.contains("shape") && msg.contains("[2, 3]")); } @@ -3233,7 +3230,7 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testInferenceWithoutLabel(){ + public void testInferenceWithoutLabel() { //We don't need a value for the label placeholder to calculate most values here SameDiff sd = SameDiff.create(); @@ -3252,15 +3249,14 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn); - Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); + Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); INDArray out = m.get("softmax"); - INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3); - Map allPh = new HashMap<>(); + Map allPh = new HashMap<>(); allPh.put("in", inputArr); allPh.put("label", labelUnused); m = sd.exec(allPh, "softmax"); @@ -3271,7 +3267,7 @@ public class SameDiffTests extends BaseNd4jTest { } @Test - public void testInferenceWithoutUnnecessaryPlaceholders(){ + public void testInferenceWithoutUnnecessaryPlaceholders() { //We don't need an array for 2 of the placeholders to calculate the SameDiff sd = SameDiff.create(); @@ -3293,15 +3289,14 @@ public class SameDiffTests extends BaseNd4jTest { INDArray inputArr = Nd4j.rand(DataType.FLOAT, minibatch, nIn); - Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); + Map m = sd.exec(Collections.singletonMap("in", inputArr), "softmax"); assertEquals(1, m.size()); assertTrue(m.containsKey("softmax")); INDArray out = m.get("softmax"); - INDArray labelUnused = Nd4j.rand(DataType.FLOAT, minibatch, 3); - Map allPh = new HashMap<>(); + Map allPh = new HashMap<>(); allPh.put("in", inputArr); allPh.put("label", labelUnused); allPh.put("in2", Nd4j.scalar(1.0f)); @@ -3314,7 +3309,7 @@ public class SameDiffTests extends BaseNd4jTest { @Test - public void testConvertDTypes1(){ + public void testConvertDTypes1() { SameDiff sd = SameDiff.create(); SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4)); @@ -3329,15 +3324,15 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.FLOAT, tanh.dataType()); assertEquals(DataType.FLOAT, stdev.dataType()); - Map out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); - for(Map.Entry e : out.entrySet()){ + Map out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); + for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType()); } assertEquals(DataType.FLOAT, x.getArr().dataType()); assertEquals(DataType.FLOAT, y.getArr().dataType()); - Map toConvert = new HashMap<>(); + Map toConvert = new HashMap<>(); toConvert.put("x", DataType.DOUBLE); toConvert.put("y", DataType.DOUBLE); sd.convertDataTypes(toConvert); @@ -3349,7 +3344,7 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(DataType.DOUBLE, stdev.dataType()); out = sd.exec(null, "x", "y", "z", "tanh", "stdev"); - for(Map.Entry e : out.entrySet()){ + for (Map.Entry e : out.entrySet()) { assertEquals(e.getKey(), DataType.DOUBLE, 
             assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
         }
 
@@ -3358,7 +3353,7 @@ public class SameDiffTests extends BaseNd4jTest {
     }
 
     @Test
-    public void testConvertDTypes2(){
+    public void testConvertDTypes2() {
         SameDiff sd = SameDiff.create();
 
         SDVariable x = sd.placeHolder("x", DataType.FLOAT, 3, 4);
@@ -3375,11 +3370,11 @@ public class SameDiffTests extends BaseNd4jTest {
         assertEquals(DataType.DOUBLE, add.dataType());
         assertEquals(DataType.DOUBLE, relu.dataType());
 
-        Map<String,INDArray> ph = Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4));
+        Map<String, INDArray> ph = Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 3, 4));
 
-        Map<String,INDArray> out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
-        for(Map.Entry<String,INDArray> e : out.entrySet()){
-            if(e.getKey().equals("x") || e.getKey().equals("y")){
+        Map<String, INDArray> out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
+        for (Map.Entry<String, INDArray> e : out.entrySet()) {
+            if (e.getKey().equals("x") || e.getKey().equals("y")) {
                 assertEquals(e.getKey(), DataType.FLOAT, e.getValue().dataType());
             } else {
                 assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
@@ -3388,7 +3383,7 @@ public class SameDiffTests extends BaseNd4jTest {
 
         assertEquals(DataType.FLOAT, y.getArr().dataType());
 
-        Map<String,DataType> toConvert = new HashMap<>();
+        Map<String, DataType> toConvert = new HashMap<>();
         toConvert.put("x", DataType.DOUBLE);
         toConvert.put("y", DataType.DOUBLE);
         sd.convertDataTypes(toConvert);
@@ -3401,7 +3396,7 @@ public class SameDiffTests extends BaseNd4jTest {
         assertEquals(DataType.DOUBLE, relu.dataType());
 
         out = sd.exec(ph, "x", "y", "xD", "yD", "a", "r");
-        for(Map.Entry<String,INDArray> e : out.entrySet()){
+        for (Map.Entry<String, INDArray> e : out.entrySet()) {
             assertEquals(e.getKey(), DataType.DOUBLE, e.getValue().dataType());
         }
 
@@ -3410,11 +3405,11 @@ public class SameDiffTests extends BaseNd4jTest {
 
     @Test
-    public void testGradFnRequiredVars(){
+    public void testGradFnRequiredVars() {
         //User can explicitly request that gradients for specific vars are available when differentiating (creating grad function),
         // even if they normally wouldn't be needed or calculated
 
-        for(boolean reqPhVar : new boolean[]{false, true}){
+        for (boolean reqPhVar : new boolean[]{false, true}) {
             // for(boolean reqPhVar : new boolean[]{true}){
 
             SameDiff sd = SameDiff.create();
@@ -3429,7 +3424,7 @@ public class SameDiffTests extends BaseNd4jTest {
 
             INDArray in = Nd4j.rand(DataType.FLOAT, 1, 5);
 
-            if(reqPhVar){
+            if (reqPhVar) {
                 sd.createGradFunction("in");
                 assertNotNull(ph.gradient());
                 assertNotNull(w.gradient());
@@ -3447,6 +3442,129 @@ public class SameDiffTests extends BaseNd4jTest {
         }
 
+    }
+
+    @Test
+    public void testIf() throws IOException {
+        SameDiff SD = SameDiff.create();
+        SDVariable a = SD.placeHolder("a", DataType.DOUBLE);
+        SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
+        SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
+
+        SDVariable output = SD.ifCond("out", null, (sd) -> a.lt(b), (sd) -> c, (sd) -> c.add(5));
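+        // ifCond takes three lambdas: predicate, true branch and false branch; only the branch
+        // selected by the predicate should be evaluated. With a = 3.0, a < b holds and "out" is
+        // c (9.0); with a = 7.0 the predicate fails and "out" is c + 5 (14.0), as asserted below.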
+
+        Map<String, INDArray> firstBranch = Maps.newHashMap();
+        firstBranch.put("a", Nd4j.createFromArray(3.0));
+        assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out"));
+
+        Map<String, INDArray> secondBranch = Maps.newHashMap();
+        secondBranch.put("a", Nd4j.createFromArray(7.0));
+        assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out"));
+
+        //TODO complains that it can't deserialize a meta type, but there are no meta type ops here
+        // looks like a difference between Op.Type and OpType. Switch is saved as an OpType.LOGIC
+        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+
+        assertEquals(Nd4j.createFromArray(9.0), SD.exec(firstBranch, "out").get("out"));
+        assertEquals(Nd4j.createFromArray(14.0), SD.exec(secondBranch, "out").get("out"));
+    }
+
+    @Test
+    public void testNestedIf() throws IOException {
+        SameDiff SD = SameDiff.create();
+        SDVariable a = SD.var("a", Nd4j.createFromArray(2.0));
+        SDVariable b = SD.var("b", Nd4j.createFromArray(5.0));
+        SDVariable c = SD.var("c", Nd4j.createFromArray(9.0));
+        SDVariable d = SD.var("d", Nd4j.createFromArray(-7.0));
+
+        SDVariable output = SD.ifCond("out", null,
+                (sd) -> a.lt(b),
+                (sd) -> sd.ifCond(
+                        (sd2) -> d.lte(0),
+                        (sd2) -> c.add(1),
+                        (sd2) -> d),
+                (sd) -> c.add(5));
+        INDArray out = output.eval();
+        assertEquals(Nd4j.createFromArray(10.0), out);
+
+        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+
+        assertEquals(Nd4j.createFromArray(10.0), SD.exec(null, "out").get("out"));
+    }
+
+    @Test
+    public void testWhile() throws IOException {
+
+        SameDiff SD = SameDiff.create();
+        SDVariable countIn = SD.constant(5);
+        SDVariable sumIn = SD.constant(0);
+
+        SDVariable[] sum = SD.whileLoop("while_1", new SDVariable[]{countIn, sumIn},
+                (sd, vars) -> vars[0].gt(0),
+                (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(vars[0])});
+
+        INDArray out = sum[1].eval();
+        assertEquals(15, out.getInt(0));
+
+        String outName = sum[1].getVarName();
+
+        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+
+        assertEquals(15, SD.exec(null, outName).get(outName).getInt(0));
+    }
+
+    @Test
+    @Ignore
+    public void testNestedWhile() throws IOException {
+        SameDiff SD = SameDiff.create();
+        SDVariable countIn = SD.constant(5);
+        SDVariable sumIn = SD.constant(0);
+        SDVariable sum2 = SD.constant(0);
+        //TODO creating constant instead of using sum2 causes errors
+
+        SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
+                (sd, vars) -> vars[0].gt(0),
+                (sd, vars) -> new SDVariable[]{vars[0].sub(1),
+                        vars[1].add(sd.whileLoop(new SDVariable[]{vars[0], sum2},
+                                (sd2, vars2) -> vars2[0].gt(0),
+                                (sd2, vars2) -> new SDVariable[]{vars2[0].sub(1), vars2[1].add(vars2[0])})[1])});
+
+        INDArray out = sum[1].eval();
+        assertEquals(35, out.getInt(0));
+
+        String outName = sum[1].getVarName();
+
+        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+
+        assertEquals(35, SD.exec(null, outName).get(outName).getInt(0));
+    }
+
+    @Test
+    public void testNestedWhileIf() throws IOException {
+        SameDiff SD = SameDiff.create();
+        SDVariable countIn = SD.constant(5);
+        SDVariable sumIn = SD.constant(0);
+        SDVariable hundred = SD.constant(100);
+
+        SDVariable[] sum = SD.whileLoop(new SDVariable[]{countIn, sumIn},
+                (sd, vars) -> vars[0].gte(0),
+                (sd, vars) -> new SDVariable[]{vars[0].sub(1), vars[1].add(
+                        sd.ifCond((sd2) -> vars[0].eq(0),
+                                (sd2) -> vars[0].add(100), //TODO replace with hundred and things break
+                                (sd2) -> vars[0])
+                )});
+
+        INDArray out = sum[1].eval();
+        assertEquals(115, out.getInt(0));
+
+        String outName = sum[1].getVarName();
+
+        SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
+
+        assertEquals(115, SD.exec(null, outName).get(outName).getInt(0));
 
     }
 }