diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index 10e3989c4..3a3caa787 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -19,6 +19,9 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -35,6 +38,10 @@ public class AbsoluteDifferenceLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ + super(lossReduce, predictions, weights, labels); + } + public AbsoluteDifferenceLoss(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java index f3bd728f5..9794c7c8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java @@ -22,7 +22,9 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; @@ -38,6 +40,16 @@ public abstract class BaseLoss extends DynamicCustomOp { addArgs(); } + public BaseLoss(@NonNull LossReduce lossReduce, @NonNull INDArray predictions, INDArray weights, @NonNull INDArray labels){ + super(new INDArray[]{predictions, getWeights(weights, predictions), labels}, null); + this.lossReduce = lossReduce; + addArgs(); + } + + protected static INDArray getWeights(INDArray weights, INDArray predictions){ + return (weights != null) ? 
weights : Nd4j.scalar(predictions.dataType(), 1.0); + } + protected BaseLoss(){ } protected void addArgs(){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index b68449c7e..241404492 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -38,6 +39,12 @@ public class CosineDistanceLoss extends BaseLoss { this.addIArgument(dimension); } + public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){ + super(lossReduce, predictions, weights, labels); + this.dimension = dimension; + this.addIArgument(dimension); + } + public CosineDistanceLoss(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index 67d59ce54..f2998064f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -35,6 +36,10 @@ public class HingeLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ + super(lossReduce, predictions, weights, labels); + } + public HingeLoss(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index ff2d3ebda..18803cd9f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -40,6 +41,12 @@ public class HuberLoss extends BaseLoss { tArguments.add(delta); } + public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){ + super(lossReduce, predictions, weights, labels); + this.delta = delta; + tArguments.add(delta); + } + public HuberLoss(){ } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index ee168e75e..e1fe56e5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -36,6 +37,10 @@ public class L2Loss extends DynamicCustomOp { super(sameDiff, new SDVariable[]{var}); } + public L2Loss(INDArray var){ + super(new INDArray[]{var}, null); + } + @Override public String opName() { return "l2_loss"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java index 83458df02..01aa283ed 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -40,6 +41,13 @@ public class LogLoss extends BaseLoss { addTArgument(epsilon); } + public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){ + super(lossReduce, predictions, weights, labels); + this.epsilon = epsilon; + addTArgument(epsilon); + } + + public LogLoss(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java index 72e275455..0e0d4f7dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -43,6 +44,12 @@ public class LogPoissonLoss extends BaseLoss { addArgs(); } + public LogPoissonLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, boolean full){ + super(lossReduce, predictions, weights, labels); + this.full = full; + addArgs(); + } + public LogPoissonLoss(){ } protected void addArgs(){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 6f33f3209..8e7bb9276 100644 
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -33,6 +34,10 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ + super(lossReduce, predictions, weights, labels); + } + public MeanPairwiseSquaredErrorLoss(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index ddab8b2a4..c38faf29a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Arrays; import java.util.List; @@ -35,6 +36,10 @@ public class MeanSquaredErrorLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ + super(lossReduce, predictions, weights, labels); + } + public MeanSquaredErrorLoss(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index b2823e712..32b176cfd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -54,6 +55,12 @@ public class SigmoidCrossEntropyLoss extends BaseLoss { this(sameDiff, reductionMode, logits, weights, labels, 0.0); } + public SigmoidCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){ + super(lossReduce, predictions, weights, labels); + this.labelSmoothing = labelSmoothing; + addArgs(); + } + public void addArgs() { super.addArgs(); addTArgument(labelSmoothing); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index 21756d99b..c8a40b805 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -56,6 +57,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { this(sameDiff, lossReduce, logits, weights, labels, 0.0); } + public SoftmaxCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){ + super(lossReduce, predictions, weights, labels); + this.labelSmoothing = labelSmoothing; + addArgs(); + } public void addArgs() { super.addArgs(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index 1d922a1e6..a0f3288a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.loss; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; @@ -24,6 +25,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -43,10 +45,13 @@ import java.util.*; @NoArgsConstructor public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { - public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) { + public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull SameDiff sameDiff, @NonNull SDVariable logits, @NonNull SDVariable labels) { super(null, sameDiff, new SDVariable[]{labels, logits}, false); } + public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull INDArray logits, @NonNull INDArray labels){ + super(new INDArray[]{labels, logits}, null); + } public void addArgs() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java index cf7c8b8a8..2ee16ff22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java @@ -17,11 +17,13 @@ package org.nd4j.linalg.api.ops.impl.loss; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; @@ -43,6 +45,9 @@ public class WeightedCrossEntropyLoss extends DynamicCustomOp { this.sameDiff = sameDiff; } + public WeightedCrossEntropyLoss(@NonNull INDArray targets, @NonNull INDArray inputs, @NonNull INDArray weights){ + super(new INDArray[] {targets, inputs, weights}, null); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 4a07d6c7a..7fe936250 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -16,10 +16,7 @@ package org.nd4j.linalg.factory; -import org.nd4j.linalg.factory.ops.NDBitwise; -import org.nd4j.linalg.factory.ops.NDMath; -import org.nd4j.linalg.factory.ops.NDNN; -import org.nd4j.linalg.factory.ops.NDRandom; +import org.nd4j.linalg.factory.ops.*; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; import lombok.NonNull; @@ -134,6 +131,11 @@ public class Nd4j { */ public static final NDNN nn = new NDNN(); + /** + * Loss function namespace - operations related to loss functions. + */ + public static final NDLoss loss = new NDLoss(); + /** * Bitwise namespace - operations related to bitwise manipulation of arrays */ @@ -162,6 +164,11 @@ public class Nd4j { return nn; } + /** + * Loss function namespace - operations related to loss functions. + */ + public static NDLoss loss(){ return loss; } + private final static String DATA_BUFFER_OPS = "databufferfactory"; private final static String CONVOLUTION_OPS = "convops"; /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java new file mode 100644 index 000000000..4c1234514 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java @@ -0,0 +1,483 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.autodiff.loss.LossReduce; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDLoss { + public NDLoss() { + } + + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output loss variable (NUMERIC type) + */ + public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce) { + NDValidation.validateNumerical("absoluteDifference", "label", label); + NDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + NDValidation.validateNumerical("absoluteDifference", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, lossReduce))[0]; + } + + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output loss variable (NUMERIC type) + */ + public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights) { + NDValidation.validateNumerical("absoluteDifference", "label", label); + NDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + NDValidation.validateNumerical("absoluteDifference", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0]; + } + + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by the L2 norm (norm2)<br>
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce, int dimension) { + NDValidation.validateNumerical("cosineDistance", "label", label); + NDValidation.validateNumerical("cosineDistance", "predictions", predictions); + NDValidation.validateNumerical("cosineDistance", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, lossReduce, dimension))[0]; + } + + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by the L2 norm (norm2)<br>
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights, + int dimension) { + NDValidation.validateNumerical("cosineDistance", "label", label); + NDValidation.validateNumerical("cosineDistance", "predictions", predictions); + NDValidation.validateNumerical("cosineDistance", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0]; + } + + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label value after internally converting to {-1,1}<br>
+ * from the user-specified {0,1}. Note that labels should be provided with values {0,1}.<br>
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce) { + NDValidation.validateNumerical("hingeLoss", "label", label); + NDValidation.validateNumerical("hingeLoss", "predictions", predictions); + NDValidation.validateNumerical("hingeLoss", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, lossReduce))[0]; + } + + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label value after internally converting to {-1,1}<br>
+ * from the user-specified {0,1}. Note that labels should be provided with values {0,1}.<br>
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights) { + NDValidation.validateNumerical("hingeLoss", "label", label); + NDValidation.validateNumerical("hingeLoss", "predictions", predictions); + NDValidation.validateNumerical("hingeLoss", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0]; + } + + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though it is less sensitive to outliers than squared error.<br>
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce, double delta) { + NDValidation.validateNumerical("huberLoss", "label", label); + NDValidation.validateNumerical("huberLoss", "predictions", predictions); + NDValidation.validateNumerical("huberLoss", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, lossReduce, delta))[0]; + } + + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though it is less sensitive to outliers than squared error.<br>
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, double delta) { + NDValidation.validateNumerical("huberLoss", "label", label); + NDValidation.validateNumerical("huberLoss", "predictions", predictions); + NDValidation.validateNumerical("huberLoss", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta))[0]; + } + + /** + * L2 loss: 1/2 * sum(x^2)
+ * + * @param var Variable to calculate L2 loss of (NUMERIC type) + * @return output L2 loss (NUMERIC type) + */ + public INDArray l2Loss(INDArray var) { + NDValidation.validateNumerical("l2Loss", "var", var); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.L2Loss(var))[0]; + } + + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon + * @return output Log loss (NUMERIC type) + */ + public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce, double epsilon) { + NDValidation.validateNumerical("logLoss", "label", label); + NDValidation.validateNumerical("logLoss", "predictions", predictions); + NDValidation.validateNumerical("logLoss", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, lossReduce, epsilon))[0]; + } + + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param epsilon epsilon + * @return output Log loss (NUMERIC type) + */ + public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) { + NDValidation.validateNumerical("logLoss", "label", label); + NDValidation.validateNumerical("logLoss", "predictions", predictions); + NDValidation.validateNumerical("logLoss", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0]; + } + + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce, boolean full) { + NDValidation.validateNumerical("logPoisson", "label", label); + NDValidation.validateNumerical("logPoisson", "predictions", predictions); + NDValidation.validateNumerical("logPoisson", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, lossReduce, full))[0]; + } + + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, boolean full) { + NDValidation.validateNumerical("logPoisson", "label", label); + NDValidation.validateNumerical("logPoisson", "predictions", predictions); + NDValidation.validateNumerical("logPoisson", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full))[0]; + } + + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between all pairs of elements in the predictions and labels arrays.<br>
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable, scalar output (NUMERIC type) + */ + public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce) { + NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, lossReduce))[0]; + } + + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between all pairs of elements in the predictions and labels arrays.<br>
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @return output Loss variable, scalar output (NUMERIC type) + */ + public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights) { + NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0]; + } + + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights, + LossReduce lossReduce) { + NDValidation.validateNumerical("meanSquaredError", "label", label); + NDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + NDValidation.validateNumerical("meanSquaredError", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, lossReduce))[0]; + } + + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights) { + NDValidation.validateNumerical("meanSquaredError", "label", label); + NDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + NDValidation.validateNumerical("meanSquaredError", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0]; + } + + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.<br>
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *

+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights, + LossReduce lossReduce, double labelSmoothing) { + NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, lossReduce, labelSmoothing))[0]; + } + + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.<br>
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *

+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights) { + NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0]; + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_c label[c] * log(p[c])} where {@code p = softmax(logits)}<br>
+ * If {@link LossReduce#NONE} is used, the returned shape is [numExamples] (one loss value per example) for [numExamples, numClasses] predictions/labels;<br>
+ * otherwise, the output is a scalar.
+ *


+ * When label smoothing is > 0, the following label smoothing is used:
+ *


+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *

+ * + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions, + INDArray weights, LossReduce lossReduce, double labelSmoothing) { + NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing))[0]; + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_c label[c] * log(p[c])} where {@code p = softmax(logits)}<br>
+ * If {@link LossReduce#NONE} is used, the returned shape is [numExamples] (one loss value per example) for [numExamples, numClasses] predictions/labels;<br>
+ * otherwise, the output is a scalar.
+ *


+ * When label smoothing is > 0, the following label smoothing is used:
+ *


+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *

+ * + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions, + INDArray weights) { + NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0]; + } + + /** + * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1
+ * + * @param logits Logits array ("pre-softmax activations") (NUMERIC type) + * @param labels Labels array. Must be an integer type. (INT type) + * @return output Softmax cross entropy (NUMERIC type) + */ + public INDArray sparseSoftmaxCrossEntropy(INDArray logits, INDArray labels) { + NDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); + NDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(logits, labels))[0]; + } + + /** + * Weighted cross entropy loss with logits
+ * + * @param targets targets array (NUMERIC type) + * @param inputs input array (NUMERIC type) + * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public INDArray weightedCrossEntropyWithLogits(INDArray targets, INDArray inputs, + INDArray weights) { + NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); + NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); + NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(targets, inputs, weights))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java new file mode 100644 index 000000000..40d32121d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java @@ -0,0 +1,453 @@ +/* ***************************************************************************** + * Copyright (c) 2019-2020 Konduit k.k. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.factory.ops; + +import org.junit.Test; +import org.nd4j.autodiff.loss.LossReduce; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class NDLossTest extends BaseNd4jTest { + public NDLossTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering(){ + return 'c'; + } + + @Test + public void testAbsoluteDifference() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + + SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, w, reduction); + SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions, null, reduction); + 
sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + INDArray y = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, wArr, reduction); + INDArray y2 = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, null, reduction); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testCosineDistance() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + predictionsArr.diviColumnVector(predictionsArr.norm2(1)); + labelsArr.diviColumnVector(labelsArr.norm2(1)); + + SDVariable loss = sd.loss().cosineDistance("loss", labels, predictions, w, reduction, 0); + SDVariable loss2 = sd.loss().cosineDistance("loss2", labels, predictions, null, reduction, 0); + sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + INDArray y = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, wArr, reduction, 0); + INDArray y2 = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, null, reduction, 0); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testHingeLoss() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + SDVariable loss = sd.loss().hingeLoss("loss", labels, predictions, w, reduction); + SDVariable loss2 = sd.loss().hingeLoss("loss2", labels, predictions, null, reduction); + sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + INDArray y = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, wArr, reduction); + INDArray y2 = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, null, reduction); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testHuberLoss() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 
2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + SDVariable loss = sd.loss().huberLoss("loss", labels, predictions, w, reduction, 0.02); + SDVariable loss2 = sd.loss().huberLoss("loss2", labels, predictions, null, reduction, 0.02); + sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + INDArray y = Nd4j.loss().huberLoss(labelsArr, predictionsArr, wArr, reduction, 0.02); + INDArray y2 = Nd4j.loss().huberLoss(labelsArr, predictionsArr, null, reduction, 0.02); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testL2Loss() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + SDVariable loss = sd.loss().l2Loss("loss", predictions); + sd.associateArrayWithVariable(predictionsArr, predictions); + + INDArray y_exp = loss.eval(); + + INDArray y = Nd4j.loss().l2Loss(predictionsArr); + assertEquals(y_exp, y); + } + + @Test + public void testLogLoss() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + Nd4j.getExecutioner().exec(new BernoulliDistribution(labelsArr, 0.5)); + predictionsArr = Nd4j.rand(predictionsArr.shape()).muli(0.8).addi(0.1); + + double eps = 1e-7; + + SDVariable loss = sd.loss().logLoss("loss", labels, predictions, w, reduction, eps); + SDVariable loss2 = sd.loss().logLoss("loss2", labels, predictions, null, reduction, eps); + sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + //TODO: Test fails. 
"Op [log_loss] execution failed" + INDArray y = Nd4j.loss().logLoss(labelsArr, predictionsArr, wArr, reduction, eps); + INDArray y2 = Nd4j.loss().logLoss(labelsArr, predictionsArr, null, reduction, eps); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testLogPoisson() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction); + SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction); + sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + INDArray y = Nd4j.loss().logPoisson(labelsArr, predictionsArr, wArr, reduction, false); + INDArray y2 = Nd4j.loss().logPoisson(labelsArr, predictionsArr, null, reduction, false); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testMeanPairwiseSquaredError() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); + + SDVariable loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction); + SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, null, reduction); + sd.associateArrayWithVariable(predictionsArr, predictions); + sd.associateArrayWithVariable(labelsArr, labels); + + INDArray y_exp = loss.eval(); + INDArray y_exp2 = loss2.eval(); + + INDArray y = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, wArr, reduction); + INDArray y2 = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, null, reduction); + assertEquals(y_exp, y); + assertEquals(y_exp2, y2); + } + + @Test + public void testMeanSquaredError() { + SameDiff sd = SameDiff.create(); + + int nOut = 4; + int minibatch = 10; + SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut); + SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut); + + INDArray wArr = Nd4j.create(new double[][]{ + {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1}, + {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}}); + SDVariable w = sd.var("weights", wArr); + + LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT; + + INDArray predictionsArr = 
+        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+        SDVariable loss = sd.loss().meanSquaredError("loss", labels, predictions, w, reduction);
+        SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, null, reduction);
+        sd.associateArrayWithVariable(predictionsArr, predictions);
+        sd.associateArrayWithVariable(labelsArr, labels);
+
+        INDArray y_exp = loss.eval();
+        INDArray y_exp2 = loss2.eval();
+
+        INDArray y = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, wArr, reduction);
+        INDArray y2 = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, null, reduction);
+        assertEquals(y_exp, y);
+        assertEquals(y_exp2, y2);
+    }
+
+    @Test
+    public void testSigmoidCrossEntropy() {
+        SameDiff sd = SameDiff.create();
+
+        int nOut = 4;
+        int minibatch = 10;
+        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+        INDArray wArr = Nd4j.create(new double[][]{
+                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+        SDVariable w = sd.var("weights", wArr);
+
+        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+        double labelSmoothing = 0.01;
+
+        SDVariable loss = sd.loss().sigmoidCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
+        SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
+        sd.associateArrayWithVariable(predictionsArr, predictions);
+        sd.associateArrayWithVariable(labelsArr, labels);
+
+        INDArray y_exp = loss.eval();
+        INDArray y_exp2 = loss2.eval();
+
+        INDArray y = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
+        INDArray y2 = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
+        assertEquals(y_exp, y);
+        assertEquals(y_exp2, y2);
+    }
+
+    @Test
+    public void testSoftmaxCrossEntropy() {
+        SameDiff sd = SameDiff.create();
+
+        int nOut = 4;
+        int minibatch = 10;
+        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+        INDArray wArr = Nd4j.scalar(1.0); //TODO: This test fails with a complex weights array.
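+        // Note: a scalar weight is used here instead of the per-example weight matrix used in the other tests,
+        // which reportedly triggers the failure noted in the TODO above.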
+        SDVariable w = null;
+
+        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+        labelsArr.assign(0);
+        for (int i = 0; i < labelsArr.size(0); i++) {
+            labelsArr.putScalar(i, i % labelsArr.size(1), 1.0);
+        }
+
+        double labelSmoothing = 0.0;
+
+        SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
+        SDVariable loss2 = sd.loss().softmaxCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
+        sd.associateArrayWithVariable(predictionsArr, predictions);
+        sd.associateArrayWithVariable(labelsArr, labels);
+
+        INDArray y_exp = loss.eval();
+        INDArray y_exp2 = loss2.eval();
+
+        INDArray y = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
+        INDArray y2 = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
+        assertEquals(y_exp, y);
+        assertEquals(y_exp2, y2);
+    }
+
+    @Test
+    public void testSparseSoftmaxCrossEntropy() {
+        SameDiff sd = SameDiff.create();
+
+        int nOut = 4;
+        int minibatch = 10;
+        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+        SDVariable labels = sd.var("labels", DataType.INT32, -1);
+
+
+        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+        INDArray labelsArr = Nd4j.create(DataType.INT32, minibatch);
+        for( int i=0; i