diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java
index 10e3989c4..3a3caa787 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java
@@ -19,6 +19,9 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -35,6 +38,10 @@ public class AbsoluteDifferenceLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
+ public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+ super(lossReduce, predictions, weights, labels);
+ }
+
public AbsoluteDifferenceLoss(){ }
@Override
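A minimal sketch of using the new eager constructor (example mine, not part of the patch; imports as in the file above). The op is executed directly through Nd4j.exec, which returns its output arrays:

    INDArray labels = Nd4j.randn(DataType.DOUBLE, 3, 4);
    INDArray predictions = Nd4j.randn(DataType.DOUBLE, 3, 4);
    // weights may be null; BaseLoss substitutes a scalar 1.0 (see next file)
    INDArray loss = Nd4j.exec(new AbsoluteDifferenceLoss(
            labels, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];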
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java
index f3bd728f5..9794c7c8b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java
@@ -22,7 +22,9 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
@@ -38,6 +40,16 @@ public abstract class BaseLoss extends DynamicCustomOp {
addArgs();
}
+ public BaseLoss(@NonNull LossReduce lossReduce, @NonNull INDArray predictions, INDArray weights, @NonNull INDArray labels){
+ super(new INDArray[]{predictions, getWeights(weights, predictions), labels}, null);
+ this.lossReduce = lossReduce;
+ addArgs();
+ }
+
+ protected static INDArray getWeights(INDArray weights, INDArray predictions){
+ return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
+ }
+
protected BaseLoss(){ }
protected void addArgs(){
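The getWeights helper keeps the native op's input signature fixed at three arrays: a null weights argument becomes a scalar 1.0 of the predictions' data type. A sketch of the two equivalent calls (example mine, using the MeanSquaredErrorLoss constructor added later in this patch):

    INDArray l = Nd4j.randn(DataType.DOUBLE, 4, 3);
    INDArray p = Nd4j.randn(DataType.DOUBLE, 4, 3);
    // both ops receive identical inputs once the null-weights default is applied
    INDArray a = Nd4j.exec(new MeanSquaredErrorLoss(l, p, null, LossReduce.SUM))[0];
    INDArray b = Nd4j.exec(new MeanSquaredErrorLoss(l, p, Nd4j.scalar(DataType.DOUBLE, 1.0), LossReduce.SUM))[0];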
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java
index b68449c7e..241404492 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java
@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -38,6 +39,12 @@ public class CosineDistanceLoss extends BaseLoss {
this.addIArgument(dimension);
}
+ public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){
+ super(lossReduce, predictions, weights, labels);
+ this.dimension = dimension;
+ this.addIArgument(dimension);
+ }
+
public CosineDistanceLoss(){ }
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java
index 67d59ce54..f2998064f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java
@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -35,6 +36,10 @@ public class HingeLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
+ public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+ super(lossReduce, predictions, weights, labels);
+ }
+
public HingeLoss(){ }
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java
index ff2d3ebda..18803cd9f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java
@@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -40,6 +41,12 @@ public class HuberLoss extends BaseLoss {
tArguments.add(delta);
}
+ public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){
+ super(lossReduce, predictions, weights, labels);
+ this.delta = delta;
+ tArguments.add(delta);
+ }
+
public HuberLoss(){ }
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
index ee168e75e..e1fe56e5f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
@@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
@@ -36,6 +37,10 @@ public class L2Loss extends DynamicCustomOp {
super(sameDiff, new SDVariable[]{var});
}
+ public L2Loss(INDArray var){
+ super(new INDArray[]{var}, null);
+ }
+
@Override
public String opName() {
return "l2_loss";
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java
index 83458df02..01aa283ed 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java
@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -40,6 +41,12 @@ public class LogLoss extends BaseLoss {
addTArgument(epsilon);
}
+ public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){
+ super(lossReduce, predictions, weights, labels);
+ this.epsilon = epsilon;
+ addTArgument(epsilon);
+ }
+
public LogLoss(){ }
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java
index 72e275455..0e0d4f7dd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java
@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -43,6 +44,12 @@ public class LogPoissonLoss extends BaseLoss {
addArgs();
}
+ public LogPoissonLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, boolean full){
+ super(lossReduce, predictions, weights, labels);
+ this.full = full;
+ addArgs();
+ }
+
public LogPoissonLoss(){ }
protected void addArgs(){
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java
index 6f33f3209..8e7bb9276 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java
@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -33,6 +34,10 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
+ public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+ super(lossReduce, predictions, weights, labels);
+ }
+
public MeanPairwiseSquaredErrorLoss(){ }
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java
index ddab8b2a4..c38faf29a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java
@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -35,6 +36,10 @@ public class MeanSquaredErrorLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
+ public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+ super(lossReduce, predictions, weights, labels);
+ }
+
public MeanSquaredErrorLoss(){ }
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
index b2823e712..32b176cfd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
@@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
@@ -54,6 +55,12 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
this(sameDiff, reductionMode, logits, weights, labels, 0.0);
}
+ public SigmoidCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){
+ super(lossReduce, predictions, weights, labels);
+ this.labelSmoothing = labelSmoothing;
+ addArgs();
+ }
+
public void addArgs() {
super.addArgs();
addTArgument(labelSmoothing);
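Per the label smoothing formula documented in NDLoss below (label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing), a smoothing value of 0.1 maps hard {0,1} labels to {0.05, 0.95}. A usage sketch of the new constructor (example mine):

    INDArray labels = Nd4j.createFromArray(new double[][]{{0, 1}, {1, 0}});
    INDArray logits = Nd4j.randn(DataType.DOUBLE, 2, 2);
    INDArray out = Nd4j.exec(new SigmoidCrossEntropyLoss(
            labels, logits, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.1))[0];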
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java
index 21756d99b..c8a40b805 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java
@@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
@@ -56,6 +57,12 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {
this(sameDiff, lossReduce, logits, weights, labels, 0.0);
}
+ public SoftmaxCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){
+ super(lossReduce, predictions, weights, labels);
+ this.labelSmoothing = labelSmoothing;
+ addArgs();
+ }
+
public void addArgs() {
super.addArgs();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
index 1d922a1e6..a0f3288a9 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
@@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.loss;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@@ -24,6 +25,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
@@ -43,10 +45,13 @@ import java.util.*;
@NoArgsConstructor
public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {
- public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
+ public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull SameDiff sameDiff, @NonNull SDVariable logits, @NonNull SDVariable labels) {
super(null, sameDiff, new SDVariable[]{labels, logits}, false);
}
+ public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull INDArray logits, @NonNull INDArray labels){
+ super(new INDArray[]{labels, logits}, null);
+ }
public void addArgs() {
}
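As with the SDVariable constructor, the inputs are passed to the native op as {labels, logits}. A sketch with integer class-index labels of rank N-1 (example mine; values arbitrary):

    INDArray logits = Nd4j.randn(DataType.FLOAT, 4, 3);    // [minibatch, numClasses]
    INDArray labels = Nd4j.createFromArray(2, 0, 1, 2);    // class indices, shape [minibatch]
    INDArray out = Nd4j.exec(new SparseSoftmaxCrossEntropyLossWithLogits(logits, labels))[0];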
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java
index cf7c8b8a8..2ee16ff22 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/WeightedCrossEntropyLoss.java
@@ -17,11 +17,13 @@
package org.nd4j.linalg.api.ops.impl.loss;
import lombok.NoArgsConstructor;
+import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
@@ -43,6 +45,9 @@ public class WeightedCrossEntropyLoss extends DynamicCustomOp {
this.sameDiff = sameDiff;
}
+ public WeightedCrossEntropyLoss(@NonNull INDArray targets, @NonNull INDArray inputs, @NonNull INDArray weights){
+ super(new INDArray[] {targets, inputs, weights}, null);
+ }
@Override
public String opName() {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
index 4a07d6c7a..7fe936250 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
@@ -16,10 +16,7 @@
package org.nd4j.linalg.factory;
-import org.nd4j.linalg.factory.ops.NDBitwise;
-import org.nd4j.linalg.factory.ops.NDMath;
-import org.nd4j.linalg.factory.ops.NDNN;
-import org.nd4j.linalg.factory.ops.NDRandom;
+import org.nd4j.linalg.factory.ops.*;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import lombok.NonNull;
@@ -134,6 +131,11 @@ public class Nd4j {
*/
public static final NDNN nn = new NDNN();
+ /**
+ * Loss function namespace - operations related to loss functions.
+ */
+ public static final NDLoss loss = new NDLoss();
+
/**
* Bitwise namespace - operations related to bitwise manipulation of arrays
*/
@@ -162,6 +164,11 @@ public class Nd4j {
return nn;
}
+ /**
+ * Loss function namespace - operations related to loss functions.
+ */
+ public static NDLoss loss(){ return loss; }
+
private final static String DATA_BUFFER_OPS = "databufferfactory";
private final static String CONVOLUTION_OPS = "convops";
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java
new file mode 100644
index 000000000..4c1234514
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java
@@ -0,0 +1,483 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
+package org.nd4j.linalg.factory.ops;
+
+import static org.nd4j.linalg.factory.NDValidation.isSameType;
+
+import org.nd4j.autodiff.loss.LossReduce;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.NDValidation;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class NDLoss {
+ public NDLoss() {
+ }
+
+ /**
+ * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @return output loss variable (NUMERIC type)
+ */
+ public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce) {
+ NDValidation.validateNumerical("absoluteDifference", "label", label);
+ NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
+ NDValidation.validateNumerical("absoluteDifference", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, lossReduce))[0];
+ }
+
+ /**
+ * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @return output loss variable (NUMERIC type)
+ */
+ public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights) {
+ NDValidation.validateNumerical("absoluteDifference", "label", label);
+ NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
+ NDValidation.validateNumerical("absoluteDifference", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
+ }
+
+ /**
+ * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @param dimension Dimension to perform the cosine distance over
+ * @return output Cosine distance loss (NUMERIC type)
+ */
+ public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce, int dimension) {
+ NDValidation.validateNumerical("cosineDistance", "label", label);
+ NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
+ NDValidation.validateNumerical("cosineDistance", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, lossReduce, dimension))[0];
+ }
+
+ /**
+ * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param dimension Dimension to perform the cosine distance over
+ * @return output Cosine distance loss (NUMERIC type)
+ */
+ public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
+ int dimension) {
+ NDValidation.validateNumerical("cosineDistance", "label", label);
+ NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
+ NDValidation.validateNumerical("cosineDistance", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
+ }
+
+ /**
+ * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)}, where t is the label value after internally converting the
+ * user-specified {0,1} labels to {-1,1}. Note that labels should be provided with values {0,1}.
+ *
+ * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce) {
+ NDValidation.validateNumerical("hingeLoss", "label", label);
+ NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
+ NDValidation.validateNumerical("hingeLoss", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, lossReduce))[0];
+ }
+
+ /**
+ * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)}, where t is the label value after internally converting the
+ * user-specified {0,1} labels to {-1,1}. Note that labels should be provided with values {0,1}.
+ *
+ * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights) {
+ NDValidation.validateNumerical("hingeLoss", "label", label);
+ NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
+ NDValidation.validateNumerical("hingeLoss", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
+ }
+
+ /**
+ * Huber loss function, used for robust regression. It is similar to both squared error loss and absolute difference loss,
+ * though it is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *
+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @param delta Loss function delta value
+ * @return output Huber loss (NUMERIC type)
+ */
+ public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce, double delta) {
+ NDValidation.validateNumerical("huberLoss", "label", label);
+ NDValidation.validateNumerical("huberLoss", "predictions", predictions);
+ NDValidation.validateNumerical("huberLoss", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, lossReduce, delta))[0];
+ }
+
+ /**
+ * Huber loss function, used for robust regression. It is similar to both squared error loss and absolute difference loss,
+ * though it is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *
+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param delta Loss function delta value
+ * @return output Huber loss (NUMERIC type)
+ */
+ public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, double delta) {
+ NDValidation.validateNumerical("huberLoss", "label", label);
+ NDValidation.validateNumerical("huberLoss", "predictions", predictions);
+ NDValidation.validateNumerical("huberLoss", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta))[0];
+ }
+
+ /**
+ * L2 loss: 1/2 * sum(x^2)
+ *
+ * @param var Variable to calculate L2 loss of (NUMERIC type)
+ * @return output L2 loss (NUMERIC type)
+ */
+ public INDArray l2Loss(INDArray var) {
+ NDValidation.validateNumerical("l2Loss", "var", var);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.L2Loss(var))[0];
+ }
+
+ /**
+ * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @param epsilon epsilon
+ * @return output Log loss (NUMERIC type)
+ */
+ public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce, double epsilon) {
+ NDValidation.validateNumerical("logLoss", "label", label);
+ NDValidation.validateNumerical("logLoss", "predictions", predictions);
+ NDValidation.validateNumerical("logLoss", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, lossReduce, epsilon))[0];
+ }
+
+ /**
+ * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param epsilon epsilon
+ * @return output Log loss (NUMERIC type)
+ */
+ public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) {
+ NDValidation.validateNumerical("logLoss", "label", label);
+ NDValidation.validateNumerical("logLoss", "predictions", predictions);
+ NDValidation.validateNumerical("logLoss", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0];
+ }
+
+ /**
+ * Log Poisson loss: a loss function used for regression on count data.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ *
+ * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
+ * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @param full Boolean flag. true for logPoissonFull, false for logPoisson
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce, boolean full) {
+ NDValidation.validateNumerical("logPoisson", "label", label);
+ NDValidation.validateNumerical("logPoisson", "predictions", predictions);
+ NDValidation.validateNumerical("logPoisson", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, lossReduce, full))[0];
+ }
+
+ /**
+ * Log Poisson loss: a loss function used for regression on count data.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ *
+ * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
+ * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param full Boolean flag. true for logPoissonFull, false for logPoisson
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, boolean full) {
+ NDValidation.validateNumerical("logPoisson", "label", label);
+ NDValidation.validateNumerical("logPoisson", "predictions", predictions);
+ NDValidation.validateNumerical("logPoisson", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full))[0];
+ }
+
+ /**
+ * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between all pairs of elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @return output Loss variable, scalar output (NUMERIC type)
+ */
+ public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce) {
+ NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
+ NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
+ NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
+ }
+
+ /**
+ * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between all pairs of elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
+ * @return output Loss variable, scalar output (NUMERIC type)
+ */
+ public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights) {
+ NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
+ NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
+ NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
+ }
+
+ /**
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
+ LossReduce lossReduce) {
+ NDValidation.validateNumerical("meanSquaredError", "label", label);
+ NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
+ NDValidation.validateNumerical("meanSquaredError", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
+ }
+
+ /**
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictions Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights) {
+ NDValidation.validateNumerical("meanSquaredError", "label", label);
+ NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
+ NDValidation.validateNumerical("meanSquaredError", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
+ }
+
+ /**
+ * Sigmoid cross entropy: applies the sigmoid activation function to the input logits (the "pre-sigmoid predictions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictionLogits Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @param labelSmoothing Label smoothing value. Default value: 0
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights,
+ LossReduce lossReduce, double labelSmoothing) {
+ NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
+ NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
+ NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, lossReduce, labelSmoothing))[0];
+ }
+
+ /**
+ * Sigmoid cross entropy: applies the sigmoid activation function to the input logits (the "pre-sigmoid predictions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerically stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *
+ *
+ * @param label Label array (NUMERIC type)
+ * @param predictionLogits Predictions array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights) {
+ NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
+ NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
+ NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
+ }
+
+ /**
+ * Applies the softmax activation function to the input, then implements multi-class cross entropy:
+ * {@code -sum_classes label[c] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, the output shape is [numExamples] for [numExamples, numClasses] predictions/labels;
+ * otherwise, the output is a scalar.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *
+ *
+ * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
+ * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
+ * @param labelSmoothing Label smoothing value. Default value: 0
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions,
+ INDArray weights, LossReduce lossReduce, double labelSmoothing) {
+ NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
+ NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
+ NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing))[0];
+ }
+
+ /**
+ * Applies the softmax activation function to the input, then implements multi-class cross entropy:
+ * {@code -sum_classes label[c] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, the output shape is [numExamples] for [numExamples, numClasses] predictions/labels;
+ * otherwise, the output is a scalar.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *
+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *
+ *
+ * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
+ * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions,
+ INDArray weights) {
+ NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
+ NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
+ NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
+ }
+
+ /**
+ * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * That is, if logits are rank N, then labels have rank N-1.
+ *
+ * @param logits Logits array ("pre-softmax activations") (NUMERIC type)
+ * @param labels Labels array. Must be an integer type. (INT type)
+ * @return output Softmax cross entropy (NUMERIC type)
+ */
+ public INDArray sparseSoftmaxCrossEntropy(INDArray logits, INDArray labels) {
+ NDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits);
+ NDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(logits, labels))[0];
+ }
+
+ /**
+ * Weighted cross entropy loss with logits
+ *
+ * @param targets targets array (NUMERIC type)
+ * @param inputs input array (NUMERIC type)
+ * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
+ * @return output Loss variable (NUMERIC type)
+ */
+ public INDArray weightedCrossEntropyWithLogits(INDArray targets, INDArray inputs,
+ INDArray weights) {
+ NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets);
+ NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs);
+ NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights);
+ return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(targets, inputs, weights))[0];
+ }
+}
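A worked sketch of the reduction modes (numbers mine, under my reading of LossReduce: NONE keeps per-element losses, while the default MEAN_BY_NONZERO_WEIGHT_COUNT divides the weighted sum by the count of nonzero-weight entries):

    INDArray label = Nd4j.createFromArray(new double[][]{{1, 2, 3}});
    INDArray pred = Nd4j.createFromArray(new double[][]{{2, 2, 5}});
    // per-element |label - prediction|: expected [[1, 0, 2]]
    INDArray perElement = Nd4j.loss().absoluteDifference(label, pred, null, LossReduce.NONE);
    // default reduction: (1 + 0 + 2) / 3 = 1.0, all three (broadcast) weights being nonzero
    INDArray reduced = Nd4j.loss().absoluteDifference(label, pred, null);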
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java
new file mode 100644
index 000000000..40d32121d
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java
@@ -0,0 +1,453 @@
+/* *****************************************************************************
+ * Copyright (c) 2019-2020 Konduit k.k.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.factory.ops;
+
+import org.junit.Test;
+import org.nd4j.autodiff.loss.LossReduce;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+public class NDLossTest extends BaseNd4jTest {
+ public NDLossTest(Nd4jBackend backend) {
+ super(backend);
+ }
+
+ @Override
+ public char ordering(){
+ return 'c';
+ }
+
+ @Test
+ public void testAbsoluteDifference() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+
+ SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, w, reduction);
+ SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions, null, reduction);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, wArr, reduction);
+ INDArray y2 = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, null, reduction);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testCosineDistance() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+ predictionsArr.diviColumnVector(predictionsArr.norm2(1));
+ labelsArr.diviColumnVector(labelsArr.norm2(1));
+
+ SDVariable loss = sd.loss().cosineDistance("loss", labels, predictions, w, reduction, 0);
+ SDVariable loss2 = sd.loss().cosineDistance("loss2", labels, predictions, null, reduction, 0);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, wArr, reduction, 0);
+ INDArray y2 = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, null, reduction, 0);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testHingeLoss() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+ SDVariable loss = sd.loss().hingeLoss("loss", labels, predictions, w, reduction);
+ SDVariable loss2 = sd.loss().hingeLoss("loss2", labels, predictions, null, reduction);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, wArr, reduction);
+ INDArray y2 = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, null, reduction);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testHuberLoss() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+ SDVariable loss = sd.loss().huberLoss("loss", labels, predictions, w, reduction, 0.02);
+ SDVariable loss2 = sd.loss().huberLoss("loss2", labels, predictions, null, reduction, 0.02);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().huberLoss(labelsArr, predictionsArr, wArr, reduction, 0.02);
+ INDArray y2 = Nd4j.loss().huberLoss(labelsArr, predictionsArr, null, reduction, 0.02);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testL2Loss() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+ SDVariable loss = sd.loss().l2Loss("loss", predictions);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+
+ INDArray y_exp = loss.eval();
+
+ INDArray y = Nd4j.loss().l2Loss(predictionsArr);
+ assertEquals(y_exp, y);
+ }
+
+ @Test
+ public void testLogLoss() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ Nd4j.getExecutioner().exec(new BernoulliDistribution(labelsArr, 0.5));
+ predictionsArr = Nd4j.rand(predictionsArr.shape()).muli(0.8).addi(0.1);
+
+ double eps = 1e-7;
+
+ SDVariable loss = sd.loss().logLoss("loss", labels, predictions, w, reduction, eps);
+ SDVariable loss2 = sd.loss().logLoss("loss2", labels, predictions, null, reduction, eps);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ //TODO: Test fails. "Op [log_loss] execution failed"
+ INDArray y = Nd4j.loss().logLoss(labelsArr, predictionsArr, wArr, reduction, eps);
+ INDArray y2 = Nd4j.loss().logLoss(labelsArr, predictionsArr, null, reduction, eps);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testLogPoisson() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
+ SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction);
+ SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
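+ // Assuming the TF-style log_poisson_loss definition, predictions are
+ // interpreted as log-rates and the per-element loss is
+ // exp(prediction) - label*prediction. The trailing 'false' in the INDArray
+ // overload below requests the reduced form, without the Stirling
+ // approximation term that full=true would add.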
+ INDArray y = Nd4j.loss().logPoisson(labelsArr, predictionsArr, wArr, reduction, false);
+ INDArray y2 = Nd4j.loss().logPoisson(labelsArr, predictionsArr, null, reduction, false);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testMeanPairwiseSquaredError() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
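+ // Unlike plain MSE, mean pairwise squared error operates on differences of
+ // element pairs within each sample: roughly the mean over pairs (i, j) of
+ // ((p_i - p_j) - (l_i - l_j))^2, assuming the TF-style definition, so it is
+ // invariant to a constant shift of the predictions.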
+ SDVariable loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction);
+ SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, null, reduction);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, wArr, reduction);
+ INDArray y2 = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, null, reduction);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testMeanSquaredError() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+
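+ // Weighted squared error (label - prediction)^2 per element. With
+ // MEAN_BY_NONZERO_WEIGHT_COUNT the weighted sum is divided by the number of
+ // nonzero weight entries rather than the total element count, hence the
+ // deliberate mix of 0, 1 and 2 valued rows in wArr above.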
+ SDVariable loss = sd.loss().meanSquaredError("loss", labels, predictions, w, reduction);
+ SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, null, reduction);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, wArr, reduction);
+ INDArray y2 = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, null, reduction);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testSigmoidCrossEntropy() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.create(new double[][]{
+ {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
+ {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
+ SDVariable w = sd.var("weights", wArr);
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ double labelSmoothing = 0.01;
+
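+ // labelSmoothing = 0.01 nudges the targets towards 0.5, roughly
+ // y -> y*(1 - 0.01) + 0.5*0.01, before the numerically stable sigmoid
+ // cross-entropy max(x, 0) - x*y + log(1 + exp(-|x|)) is applied
+ // (assuming the TF-style definition).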
+ SDVariable loss = sd.loss().sigmoidCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
+ SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
+ INDArray y2 = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testSoftmaxCrossEntropy() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+
+ INDArray wArr = Nd4j.scalar(1.0); //TODO: This test fails when a non-scalar weights array is used.
+ SDVariable w = null;
+
+ LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ labelsArr.assign(0);
+ for (int i = 0; i < labelsArr.size(0); i++) {
+ labelsArr.putScalar(i, i % labelsArr.size(1), 1.0);
+ }
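+ // The loop above builds one-hot label rows, one active class per example.
+ // Assuming the TF-style definition, the per-row loss is
+ // -sum_j y_j * log(softmax(x)_j), and labelSmoothing = 0.0 below leaves the
+ // one-hot targets untouched.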
+
+ double labelSmoothing = 0.0;
+
+ SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
+ SDVariable loss2 = sd.loss().softmaxCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
+ sd.associateArrayWithVariable(predictionsArr, predictions);
+ sd.associateArrayWithVariable(labelsArr, labels);
+
+ INDArray y_exp = loss.eval();
+ INDArray y_exp2 = loss2.eval();
+
+ INDArray y = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
+ INDArray y2 = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
+ assertEquals(y_exp, y);
+ assertEquals(y_exp2, y2);
+ }
+
+ @Test
+ public void testSparseSoftmaxCrossEntropy() {
+ SameDiff sd = SameDiff.create();
+
+ int nOut = 4;
+ int minibatch = 10;
+ SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
+ SDVariable labels = sd.var("labels", DataType.INT32, -1);
+
+ INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
+ INDArray labelsArr = Nd4j.create(DataType.INT32, minibatch);
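+ // The sparse variant consumes integer class indices of shape [minibatch]
+ // rather than one-hot rows, which is why no weights array is defined here.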
+ for( int i=0; i