Loss namespace (#294)
* codegen for SDLoss. WIP. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* first pass of SDLoss. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* wip. First cut of new op constructors. UNTESTED, NOT COMPILED YET. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* updated op signatures. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* add NDLoss tests. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* fix test. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* adds loss default params. factory. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Regenerate NDLoss. Signed-off-by: AlexDBlack <blacka101@gmail.com>
* adds tests for null weights. Signed-off-by: Robert Altena <Rob@Ra-ai.com>
* Last few tweaks. Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Robert Altena <Rob@Ra-ai.com>
parent 7494117e90
commit e6a7b94fe4
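The diff below adds INDArray-based constructors to the loss ops, a null-weight default in BaseLoss, the generated NDLoss namespace class, its wiring into Nd4j, and NDLossTest. As a quick orientation, here is a minimal usage sketch of the new namespace based on the methods added in this commit; the wrapper class name and the array shapes are illustrative, not taken from the PR:

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LossNamespaceExample {
    public static void main(String[] args) {
        // Illustrative data: 10 examples, 4 outputs each
        INDArray labels = Nd4j.randn(DataType.DOUBLE, 10, 4);
        INDArray predictions = Nd4j.randn(DataType.DOUBLE, 10, 4);

        // Passing null weights relies on the scalar-1.0 default added in BaseLoss.getWeights (see diff below)
        INDArray mse = Nd4j.loss().meanSquaredError(labels, predictions, null,
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);

        // The overload without a LossReduce argument defaults to MEAN_BY_NONZERO_WEIGHT_COUNT
        INDArray mae = Nd4j.loss().absoluteDifference(labels, predictions, null);

        System.out.println("MSE: " + mse + ", MAE: " + mae);
    }
}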
AbsoluteDifferenceLoss.java:

@@ -19,6 +19,9 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -35,6 +38,10 @@ public class AbsoluteDifferenceLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+        super(lossReduce, predictions, weights, labels);
+    }
+
     public AbsoluteDifferenceLoss(){ }
 
     @Override
BaseLoss.java:

@@ -22,7 +22,9 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.Collections;
 import java.util.List;
@@ -38,6 +40,16 @@ public abstract class BaseLoss extends DynamicCustomOp {
         addArgs();
     }
 
+    public BaseLoss(@NonNull LossReduce lossReduce, @NonNull INDArray predictions, INDArray weights, @NonNull INDArray labels){
+        super(new INDArray[]{predictions, getWeights(weights, predictions), labels}, null);
+        this.lossReduce = lossReduce;
+        addArgs();
+    }
+
+    protected static INDArray getWeights(INDArray weights, INDArray predictions){
+        return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
+    }
+
     protected BaseLoss(){ }
 
     protected void addArgs(){
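The BaseLoss change above is the core of the eager API: each loss op gains an INDArray constructor that feeds the arrays straight into the DynamicCustomOp inputs, and getWeights substitutes a scalar 1.0 when weights is null. A rough sketch of the call path the generated NDLoss wrappers use, built on the MeanSquaredErrorLoss constructor added later in this diff (variable names are mine, not from the PR):

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss;
import org.nd4j.linalg.factory.Nd4j;

public class BaseLossSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.randn(DataType.DOUBLE, 10, 4);
        INDArray predictions = Nd4j.randn(DataType.DOUBLE, 10, 4);

        // weights == null: BaseLoss.getWeights(null, predictions) falls back to Nd4j.scalar(dataType, 1.0)
        MeanSquaredErrorLoss op = new MeanSquaredErrorLoss(labels, predictions, null,
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);

        // Same execution pattern the generated NDLoss wrapper uses: Nd4j.exec(op)[0]
        INDArray loss = Nd4j.exec(op)[0];
        System.out.println(loss);
    }
}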
CosineDistanceLoss.java:

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -38,6 +39,12 @@ public class CosineDistanceLoss extends BaseLoss {
         this.addIArgument(dimension);
     }
 
+    public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){
+        super(lossReduce, predictions, weights, labels);
+        this.dimension = dimension;
+        this.addIArgument(dimension);
+    }
+
     public CosineDistanceLoss(){ }
 
     @Override
HingeLoss.java:

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -35,6 +36,10 @@ public class HingeLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+        super(lossReduce, predictions, weights, labels);
+    }
+
     public HingeLoss(){ }
 
     @Override
HuberLoss.java:

@@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -40,6 +41,12 @@ public class HuberLoss extends BaseLoss {
         tArguments.add(delta);
     }
 
+    public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){
+        super(lossReduce, predictions, weights, labels);
+        this.delta = delta;
+        tArguments.add(delta);
+    }
+
     public HuberLoss(){ }
 
     @Override
L2Loss.java:

@@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.Collections;
@@ -36,6 +37,10 @@ public class L2Loss extends DynamicCustomOp {
         super(sameDiff, new SDVariable[]{var});
     }
 
+    public L2Loss(INDArray var){
+        super(new INDArray[]{var}, null);
+    }
+
     @Override
     public String opName() {
         return "l2_loss";
LogLoss.java:

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -40,6 +41,13 @@ public class LogLoss extends BaseLoss {
         addTArgument(epsilon);
     }
 
+    public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){
+        super(lossReduce, predictions, weights, labels);
+        this.epsilon = epsilon;
+        addTArgument(epsilon);
+    }
+
+
     public LogLoss(){ }
 
     @Override
LogPoissonLoss.java:

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -43,6 +44,12 @@ public class LogPoissonLoss extends BaseLoss {
         addArgs();
     }
 
+    public LogPoissonLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, boolean full){
+        super(lossReduce, predictions, weights, labels);
+        this.full = full;
+        addArgs();
+    }
+
     public LogPoissonLoss(){ }
 
     protected void addArgs(){
MeanPairwiseSquaredErrorLoss.java:

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -33,6 +34,10 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+        super(lossReduce, predictions, weights, labels);
+    }
+
     public MeanPairwiseSquaredErrorLoss(){ }
 
     @Override
MeanSquaredErrorLoss.java:

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.Arrays;
 import java.util.List;
@@ -35,6 +36,10 @@ public class MeanSquaredErrorLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
+        super(lossReduce, predictions, weights, labels);
+    }
+
     public MeanSquaredErrorLoss(){ }
 
     @Override
SigmoidCrossEntropyLoss.java:

@@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
 import org.tensorflow.framework.AttrValue;
@@ -54,6 +55,12 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
         this(sameDiff, reductionMode, logits, weights, labels, 0.0);
     }
 
+    public SigmoidCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){
+        super(lossReduce, predictions, weights, labels);
+        this.labelSmoothing = labelSmoothing;
+        addArgs();
+    }
+
     public void addArgs() {
         super.addArgs();
         addTArgument(labelSmoothing);
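The SigmoidCrossEntropyLoss constructor above carries labelSmoothing through addArgs() as a t-argument. A hedged sketch of calling it through the new namespace with and without smoothing (the input values are arbitrary, and the helper class name is made up):

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SigmoidCrossEntropySketch {
    public static void main(String[] args) {
        // Binary multi-label targets in {0,1} and raw (pre-sigmoid) logits
        INDArray labels = Nd4j.rand(DataType.DOUBLE, 10, 4).gt(0.5).castTo(DataType.DOUBLE);
        INDArray logits = Nd4j.randn(DataType.DOUBLE, 10, 4);

        // Explicit reduction plus a small label-smoothing coefficient
        INDArray smoothed = Nd4j.loss().sigmoidCrossEntropy(labels, logits, null,
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.1);

        // Short overload: default reduction and labelSmoothing = 0.0
        INDArray plain = Nd4j.loss().sigmoidCrossEntropy(labels, logits, null);

        System.out.println(smoothed + " " + plain);
    }
}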
SoftmaxCrossEntropyLoss.java:

@@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
 import org.tensorflow.framework.AttrValue;
@@ -56,6 +57,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {
         this(sameDiff, lossReduce, logits, weights, labels, 0.0);
     }
 
+    public SoftmaxCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){
+        super(lossReduce, predictions, weights, labels);
+        this.labelSmoothing = labelSmoothing;
+        addArgs();
+    }
 
     public void addArgs() {
         super.addArgs();
SparseSoftmaxCrossEntropyLossWithLogits.java:

@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.loss;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@@ -24,6 +25,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
 import org.tensorflow.framework.AttrValue;
@@ -43,10 +45,13 @@ import java.util.*;
 @NoArgsConstructor
 public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {
 
-    public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
+    public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull SameDiff sameDiff, @NonNull SDVariable logits, @NonNull SDVariable labels) {
         super(null, sameDiff, new SDVariable[]{labels, logits}, false);
     }
 
+    public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull INDArray logits, @NonNull INDArray labels){
+        super(new INDArray[]{labels, logits}, null);
+    }
+
     public void addArgs() {
     }
WeightedCrossEntropyLoss.java:

@@ -17,11 +17,13 @@
 package org.nd4j.linalg.api.ops.impl.loss;
 
 import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.Op;
 
@@ -43,6 +45,9 @@ public class WeightedCrossEntropyLoss extends DynamicCustomOp {
         this.sameDiff = sameDiff;
     }
 
+    public WeightedCrossEntropyLoss(@NonNull INDArray targets, @NonNull INDArray inputs, @NonNull INDArray weights){
+        super(new INDArray[] {targets, inputs, weights}, null);
+    }
+
     @Override
     public String opName() {
Nd4j.java:

@@ -16,10 +16,7 @@
 
 package org.nd4j.linalg.factory;
 
-import org.nd4j.linalg.factory.ops.NDBitwise;
-import org.nd4j.linalg.factory.ops.NDMath;
-import org.nd4j.linalg.factory.ops.NDNN;
-import org.nd4j.linalg.factory.ops.NDRandom;
+import org.nd4j.linalg.factory.ops.*;
 import org.nd4j.shade.guava.primitives.Ints;
 import org.nd4j.shade.guava.primitives.Longs;
 import lombok.NonNull;
@@ -134,6 +131,11 @@ public class Nd4j {
      */
    public static final NDNN nn = new NDNN();
 
+    /**
+     * Loss function namespace - operations related to loss functions.
+     */
+    public static final NDLoss loss = new NDLoss();
+
     /**
      * Bitwise namespace - operations related to bitwise manipulation of arrays
      */
@@ -162,6 +164,11 @@ public class Nd4j {
         return nn;
     }
 
+    /**
+     * Loss function namespace - operations related to loss functions.
+     */
+    public static NDLoss loss(){ return loss; }
+
     private final static String DATA_BUFFER_OPS = "databufferfactory";
     private final static String CONVOLUTION_OPS = "convops";
     /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
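In Nd4j.java the namespace is exposed both as a static field and through a loss() accessor, mirroring the existing nn namespace above it. A small sketch of the two access styles (only l2Loss is called here, purely for illustration; the class name is made up):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDLoss;

public class NamespaceAccessSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.randn(DataType.DOUBLE, 3, 5);

        // Field access and method access both resolve to the same static NDLoss instance
        NDLoss viaField = Nd4j.loss;
        NDLoss viaMethod = Nd4j.loss();
        System.out.println(viaField == viaMethod);  // true

        // l2Loss computes 1/2 * sum(x^2), per the NDLoss javadoc in the new file below
        INDArray l2 = viaMethod.l2Loss(x);
        System.out.println(l2);
    }
}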
NDLoss.java (new file, package org.nd4j.linalg.factory.ops):

@@ -0,0 +1,483 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================

package org.nd4j.linalg.factory.ops;

import static org.nd4j.linalg.factory.NDValidation.isSameType;

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

public class NDLoss {
  public NDLoss() {
  }

  /**
   * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @return output loss variable (NUMERIC type)
   */
  public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce) {
    NDValidation.validateNumerical("absoluteDifference", "label", label);
    NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
    NDValidation.validateNumerical("absoluteDifference", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, lossReduce))[0];
  }

  /**
   * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @return output loss variable (NUMERIC type)
   */
  public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights) {
    NDValidation.validateNumerical("absoluteDifference", "label", label);
    NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
    NDValidation.validateNumerical("absoluteDifference", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
  }

  /**
   * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is<br>
   * equivalent to cosine distance when both the predictions and labels are normalized.<br>
   * <b>Note</b>: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
   * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
   * along the cosine distance dimension (with keepDims=true).<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @param dimension Dimension to perform the cosine distance over
   * @return output Cosine distance loss (NUMERIC type)
   */
  public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce, int dimension) {
    NDValidation.validateNumerical("cosineDistance", "label", label);
    NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
    NDValidation.validateNumerical("cosineDistance", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, lossReduce, dimension))[0];
  }

  /**
   * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is<br>
   * equivalent to cosine distance when both the predictions and labels are normalized.<br>
   * <b>Note</b>: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
   * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
   * along the cosine distance dimension (with keepDims=true).<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
   * @param dimension Dimension to perform the cosine distance over
   * @return output Cosine distance loss (NUMERIC type)
   */
  public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
      int dimension) {
    NDValidation.validateNumerical("cosineDistance", "label", label);
    NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
    NDValidation.validateNumerical("cosineDistance", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
  }

  /**
   * Hinge loss: a loss function used for training classifiers.<br>
   * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
   * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.<br>
   *
   * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce) {
    NDValidation.validateNumerical("hingeLoss", "label", label);
    NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
    NDValidation.validateNumerical("hingeLoss", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, lossReduce))[0];
  }

  /**
   * Hinge loss: a loss function used for training classifiers.<br>
   * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
   * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.<br>
   *
   * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights) {
    NDValidation.validateNumerical("hingeLoss", "label", label);
    NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
    NDValidation.validateNumerical("hingeLoss", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
  }

  /**
   * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,<br>
   * though is less sensitive to outliers than squared error.<br>
   * Huber loss implements:<br>
   * <pre><br>
   * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}<br>
   * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}<br>
   * </pre><br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @param delta Loss function delta value
   * @return output Huber loss (NUMERIC type)
   */
  public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce, double delta) {
    NDValidation.validateNumerical("huberLoss", "label", label);
    NDValidation.validateNumerical("huberLoss", "predictions", predictions);
    NDValidation.validateNumerical("huberLoss", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, lossReduce, delta))[0];
  }

  /**
   * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,<br>
   * though is less sensitive to outliers than squared error.<br>
   * Huber loss implements:<br>
   * <pre><br>
   * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}<br>
   * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}<br>
   * </pre><br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param delta Loss function delta value
   * @return output Huber loss (NUMERIC type)
   */
  public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, double delta) {
    NDValidation.validateNumerical("huberLoss", "label", label);
    NDValidation.validateNumerical("huberLoss", "predictions", predictions);
    NDValidation.validateNumerical("huberLoss", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta))[0];
  }

  /**
   * L2 loss: 1/2 * sum(x^2)<br>
   *
   * @param var Variable to calculate L2 loss of (NUMERIC type)
   * @return output L2 loss (NUMERIC type)
   */
  public INDArray l2Loss(INDArray var) {
    NDValidation.validateNumerical("l2Loss", "var", var);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.L2Loss(var))[0];
  }

  /**
   * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:<br>
   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @param epsilon epsilon
   * @return output Log loss (NUMERIC type)
   */
  public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce, double epsilon) {
    NDValidation.validateNumerical("logLoss", "label", label);
    NDValidation.validateNumerical("logLoss", "predictions", predictions);
    NDValidation.validateNumerical("logLoss", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, lossReduce, epsilon))[0];
  }

  /**
   * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:<br>
   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param epsilon epsilon
   * @return output Log loss (NUMERIC type)
   */
  public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) {
    NDValidation.validateNumerical("logLoss", "label", label);
    NDValidation.validateNumerical("logLoss", "predictions", predictions);
    NDValidation.validateNumerical("logLoss", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0];
  }

  /**
   * Log poisson loss: a loss function used for training classifiers.<br>
   * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
   *
   * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
   * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @param full Boolean flag. true for logPoissonFull, false for logPoisson
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce, boolean full) {
    NDValidation.validateNumerical("logPoisson", "label", label);
    NDValidation.validateNumerical("logPoisson", "predictions", predictions);
    NDValidation.validateNumerical("logPoisson", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, lossReduce, full))[0];
  }

  /**
   * Log poisson loss: a loss function used for training classifiers.<br>
   * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
   *
   * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
   * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param full Boolean flag. true for logPoissonFull, false for logPoisson
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, boolean full) {
    NDValidation.validateNumerical("logPoisson", "label", label);
    NDValidation.validateNumerical("logPoisson", "predictions", predictions);
    NDValidation.validateNumerical("logPoisson", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full))[0];
  }

  /**
   * Mean pairwise squared error.<br>
   * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.<br>
   * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:<br>
   * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @return output Loss variable, scalar output (NUMERIC type)
   */
  public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce) {
    NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
    NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
    NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
  }

  /**
   * Mean pairwise squared error.<br>
   * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.<br>
   * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:<br>
   * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
   * @return output Loss variable, scalar output (NUMERIC type)
   */
  public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights) {
    NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
    NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
    NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
  }

  /**
   * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
   * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
   * this is the mean squared error loss function.<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
      LossReduce lossReduce) {
    NDValidation.validateNumerical("meanSquaredError", "label", label);
    NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
    NDValidation.validateNumerical("meanSquaredError", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
  }

  /**
   * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
   * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
   * this is the mean squared error loss function.<br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictions Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights) {
    NDValidation.validateNumerical("meanSquaredError", "label", label);
    NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
    NDValidation.validateNumerical("meanSquaredError", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
  }

  /**
   * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")<br>
   * and implements the binary cross entropy loss function. This implementation is numerically more stable than using<br>
   * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.<br>
   * Implements:<br>
   * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}<br>
   * though this is done in a mathematically equivalent but more numerical stable form.<br>
   * <br>
   * When label smoothing is > 0, the following label smoothing is used:<br>
   * <pre><br>
   * {@code numClasses = labels.size(1);<br>
   * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}<br>
   * </pre><br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictionLogits Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @param labelSmoothing Label smoothing value. Default value: 0
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights,
      LossReduce lossReduce, double labelSmoothing) {
    NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
    NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
    NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, lossReduce, labelSmoothing))[0];
  }

  /**
   * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")<br>
   * and implements the binary cross entropy loss function. This implementation is numerically more stable than using<br>
   * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.<br>
   * Implements:<br>
   * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}<br>
   * though this is done in a mathematically equivalent but more numerical stable form.<br>
   * <br>
   * When label smoothing is > 0, the following label smoothing is used:<br>
   * <pre><br>
   * {@code numClasses = labels.size(1);<br>
   * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}<br>
   * </pre><br>
   *
   * @param label Label array (NUMERIC type)
   * @param predictionLogits Predictions array (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights) {
    NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
    NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
    NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
  }

  /**
   * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
   * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
   * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
   * otherwise, the output is a scalar.<br>
   * <p><br>
   * When label smoothing is > 0, the following label smoothing is used:<br>
   * <pre><br>
   * {@code numClasses = labels.size(1);<br>
   * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}<br>
   * </pre><br>
   *
   * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
   * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
   * @param labelSmoothing Label smoothing value. Default value: 0
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions,
      INDArray weights, LossReduce lossReduce, double labelSmoothing) {
    NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
    NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
    NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing))[0];
  }

  /**
   * Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
   * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
   * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
   * otherwise, the output is a scalar.<br>
   * <p><br>
   * When label smoothing is > 0, the following label smoothing is used:<br>
   * <pre><br>
   * {@code numClasses = labels.size(1);<br>
   * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}<br>
   * </pre><br>
   *
   * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
   * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
   * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions,
      INDArray weights) {
    NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
    NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
    NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
  }

  /**
   * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable<br>
   * is represented as an integer array instead of the equivalent one-hot array.<br>
   * i.e., if logits are rank N, then labels have rank N-1<br>
   *
   * @param logits Logits array ("pre-softmax activations") (NUMERIC type)
   * @param labels Labels array. Must be an integer type. (INT type)
   * @return output Softmax cross entropy (NUMERIC type)
   */
  public INDArray sparseSoftmaxCrossEntropy(INDArray logits, INDArray labels) {
    NDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits);
    NDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(logits, labels))[0];
  }

  /**
   * Weighted cross entropy loss with logits<br>
   *
   * @param targets targets array (NUMERIC type)
   * @param inputs input array (NUMERIC type)
   * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
   * @return output Loss variable (NUMERIC type)
   */
  public INDArray weightedCrossEntropyWithLogits(INDArray targets, INDArray inputs,
      INDArray weights) {
    NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets);
    NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs);
    NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights);
    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(targets, inputs, weights))[0];
  }
}
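Before the test file, a quick hedged example covering the two cross-entropy entry points that the earlier sketches did not touch; the shapes, label values, and class name are made up for illustration:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CrossEntropySketch {
    public static void main(String[] args) {
        int minibatch = 8;
        int numClasses = 5;
        INDArray logits = Nd4j.randn(DataType.DOUBLE, minibatch, numClasses);

        // One-hot labels for softmaxCrossEntropy: same shape as the logits
        INDArray oneHot = Nd4j.zeros(DataType.DOUBLE, minibatch, numClasses);
        for (int i = 0; i < minibatch; i++) {
            oneHot.putScalar(i, i % numClasses, 1.0);
        }

        // Short overload: default reduction, labelSmoothing = 0.0
        INDArray sce = Nd4j.loss().softmaxCrossEntropy(oneHot, logits, null);

        // Sparse variant: integer class indices, rank one lower than the logits
        INDArray classIdx = Nd4j.createFromArray(0, 1, 2, 3, 4, 0, 1, 2);
        INDArray sparse = Nd4j.loss().sparseSoftmaxCrossEntropy(logits, classIdx);

        System.out.println(sce + " " + sparse);
    }
}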
@ -0,0 +1,453 @@
|
||||||
|
/* *****************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit k.k.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.factory.ops;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.autodiff.loss.LossReduce;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class NDLossTest extends BaseNd4jTest {
|
||||||
|
public NDLossTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering(){
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
|
    @Test
    public void testAbsoluteDifference() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, w, reduction);
        SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions, null, reduction);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, wArr, reduction);
        INDArray y2 = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, null, reduction);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testCosineDistance() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        //Normalize each row to unit L2 norm (cosine distance assumes unit-length vectors)
        predictionsArr.diviColumnVector(predictionsArr.norm2(1));
        labelsArr.diviColumnVector(labelsArr.norm2(1));

        SDVariable loss = sd.loss().cosineDistance("loss", labels, predictions, w, reduction, 0);
        SDVariable loss2 = sd.loss().cosineDistance("loss2", labels, predictions, null, reduction, 0);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, wArr, reduction, 0);
        INDArray y2 = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, null, reduction, 0);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testHingeLoss() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().hingeLoss("loss", labels, predictions, w, reduction);
        SDVariable loss2 = sd.loss().hingeLoss("loss2", labels, predictions, null, reduction);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, wArr, reduction);
        INDArray y2 = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, null, reduction);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testHuberLoss() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().huberLoss("loss", labels, predictions, w, reduction, 0.02);
        SDVariable loss2 = sd.loss().huberLoss("loss2", labels, predictions, null, reduction, 0.02);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().huberLoss(labelsArr, predictionsArr, wArr, reduction, 0.02);
        INDArray y2 = Nd4j.loss().huberLoss(labelsArr, predictionsArr, null, reduction, 0.02);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testL2Loss() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().l2Loss("loss", predictions);
        sd.associateArrayWithVariable(predictionsArr, predictions);

        INDArray y_exp = loss.eval();

        INDArray y = Nd4j.loss().l2Loss(predictionsArr);
        assertEquals(y_exp, y);
    }

    @Test
    public void testLogLoss() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        //Binary labels in {0,1}; predictions kept in (0.1, 0.9) so log(p) and log(1-p) stay finite
        Nd4j.getExecutioner().exec(new BernoulliDistribution(labelsArr, 0.5));
        predictionsArr = Nd4j.rand(predictionsArr.shape()).muli(0.8).addi(0.1);

        double eps = 1e-7;

        SDVariable loss = sd.loss().logLoss("loss", labels, predictions, w, reduction, eps);
        SDVariable loss2 = sd.loss().logLoss("loss2", labels, predictions, null, reduction, eps);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        //TODO: Test fails. "Op [log_loss] execution failed"
        INDArray y = Nd4j.loss().logLoss(labelsArr, predictionsArr, wArr, reduction, eps);
        INDArray y2 = Nd4j.loss().logLoss(labelsArr, predictionsArr, null, reduction, eps);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testLogPoisson() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction);
        SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().logPoisson(labelsArr, predictionsArr, wArr, reduction, false);
        INDArray y2 = Nd4j.loss().logPoisson(labelsArr, predictionsArr, null, reduction, false);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testMeanPairwiseSquaredError() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction);
        SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, null, reduction);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, wArr, reduction);
        INDArray y2 = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, null, reduction);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testMeanSquaredError() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);

        SDVariable loss = sd.loss().meanSquaredError("loss", labels, predictions, w, reduction);
        SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, null, reduction);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, wArr, reduction);
        INDArray y2 = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, null, reduction);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testSigmoidCrossEntropy() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.create(new double[][]{
                {0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
                {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
        SDVariable w = sd.var("weights", wArr);

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        double labelSmoothing = 0.01;

        SDVariable loss = sd.loss().sigmoidCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
        SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
        INDArray y2 = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testSoftmaxCrossEntropy() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);

        INDArray wArr = Nd4j.scalar(1.0); //TODO: This test fails with a complex weights array.
        SDVariable w = null;

        LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        //Replace the random labels with one-hot rows (row i gets class i % nOut)
        labelsArr.assign(0);
        for (int i = 0; i < labelsArr.size(0); i++) {
            labelsArr.putScalar(i, i % labelsArr.size(1), 1.0);
        }

        double labelSmoothing = 0.0;

        SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
        SDVariable loss2 = sd.loss().softmaxCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();
        INDArray y_exp2 = loss2.eval();

        INDArray y = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
        INDArray y2 = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
        assertEquals(y_exp, y);
        assertEquals(y_exp2, y2);
    }

    @Test
    public void testSparseSoftmaxCrossEntropy() {
        SameDiff sd = SameDiff.create();

        int nOut = 4;
        int minibatch = 10;
        SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
        SDVariable labels = sd.var("labels", DataType.INT32, -1);

        INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
        //Sparse format: labels are integer class indices rather than one-hot vectors
        INDArray labelsArr = Nd4j.create(DataType.INT32, minibatch);
        for( int i=0; i<minibatch; i++ ){
            labelsArr.putScalar(i, i%nOut);
        }

        SDVariable loss = sd.loss().sparseSoftmaxCrossEntropy("loss", predictions, labels);
        sd.associateArrayWithVariable(predictionsArr, predictions);
        sd.associateArrayWithVariable(labelsArr, labels);

        INDArray y_exp = loss.eval();

        INDArray y = Nd4j.loss().sparseSoftmaxCrossEntropy(predictionsArr, labelsArr);
        assertEquals(y_exp, y);
    }

    @Test
    public void testWeightedCrossEntropyWithLogits() {
        // This one from SamediffTests.java
        SameDiff sameDiff = SameDiff.create();
        INDArray targets = Nd4j.create(new long[]{1, 5});
        INDArray inputs = Nd4j.create(new long[]{1, 5});
        INDArray weights = Nd4j.create(new long[]{1, 5});

        SDVariable sdInputs = sameDiff.var("inputs", inputs);
        SDVariable sdWeights = sameDiff.var("weights", weights);
        SDVariable sdTargets = sameDiff.var("targets", targets);

        SDVariable res = sameDiff.loss().weightedCrossEntropyWithLogits(sdTargets, sdInputs, sdWeights);

        INDArray resultArray = res.eval();
        assertArrayEquals(new long[]{1, 5}, resultArray.shape());

        // Make sure the INDArray interface produces the same result.
        INDArray y = Nd4j.loss().weightedCrossEntropyWithLogits(targets, inputs, weights);
        assertEquals(resultArray, y);
    }
}
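
For orientation only, and not part of this commit's diff: a minimal sketch of calling the INDArray-based loss namespace that the tests above exercise. The class and method names below are illustrative; the shapes, the MEAN_BY_NONZERO_WEIGHT_COUNT reduction, and the null-weights call mirror the test setup above.

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Illustrative sketch only - the names below are not from the commit
public class NDLossExample {
    public static void main(String[] args) {
        // Shapes mirror the tests above: minibatch of 10, 4 outputs
        INDArray labels = Nd4j.randn(DataType.DOUBLE, 10, 4);
        INDArray predictions = Nd4j.randn(DataType.DOUBLE, 10, 4);

        // Mean squared error via the INDArray loss namespace; null weights are
        // accepted, as the tests above exercise
        INDArray mse = Nd4j.loss().meanSquaredError(labels, predictions, null,
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
        System.out.println(mse);
    }
}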