Loss namespace (#294)

* codegen for SDLoss. WIP.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* first pass of SDLoss.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* WIP. First cut of new op constructors. UNTESTED, NOT COMPILED YET.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* updated op signatures.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* add NDLoss tests.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* fix test.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* adds loss default params and factory.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Regenerate NDLoss

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* adds tests for null weights.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Last few tweaks

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Robert Altena <Rob@Ra-ai.com>
Alex Black 2020-03-06 16:07:22 +11:00 committed by GitHub
parent 7494117e90
commit e6a7b94fe4
17 changed files with 1039 additions and 5 deletions

View File

@@ -19,6 +19,9 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -35,6 +38,10 @@ public class AbsoluteDifferenceLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
super(lossReduce, predictions, weights, labels);
}
public AbsoluteDifferenceLoss(){ }
@Override

View File

@@ -22,7 +22,9 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
@@ -38,6 +40,16 @@ public abstract class BaseLoss extends DynamicCustomOp {
addArgs();
}
public BaseLoss(@NonNull LossReduce lossReduce, @NonNull INDArray predictions, INDArray weights, @NonNull INDArray labels){
super(new INDArray[]{predictions, getWeights(weights, predictions), labels}, null);
this.lossReduce = lossReduce;
addArgs();
}
protected static INDArray getWeights(INDArray weights, INDArray predictions){
return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
}
protected BaseLoss(){ }
protected void addArgs(){

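For illustration only (not part of the diff): a minimal sketch of what the null-weights default above is intended to enable, assuming the NDLoss namespace added later in this commit. Passing null weights should behave the same as passing an explicit scalar weight of 1.0.

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NullWeightsSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.randn(DataType.DOUBLE, 3, 4);
        INDArray predictions = Nd4j.randn(DataType.DOUBLE, 3, 4);
        // null weights -> BaseLoss.getWeights substitutes Nd4j.scalar(predictions.dataType(), 1.0)
        INDArray a = Nd4j.loss().meanSquaredError(labels, predictions, null,
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
        INDArray b = Nd4j.loss().meanSquaredError(labels, predictions, Nd4j.scalar(1.0),
                LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
        System.out.println(a.equals(b)); // expected: true
    }
}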
View File

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -38,6 +39,12 @@ public class CosineDistanceLoss extends BaseLoss {
this.addIArgument(dimension);
}
public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){
super(lossReduce, predictions, weights, labels);
this.dimension = dimension;
this.addIArgument(dimension);
}
public CosineDistanceLoss(){ }
@Override

View File

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -35,6 +36,10 @@ public class HingeLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
super(lossReduce, predictions, weights, labels);
}
public HingeLoss(){ }
@Override

View File

@@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -40,6 +41,12 @@ public class HuberLoss extends BaseLoss {
tArguments.add(delta);
}
public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){
super(lossReduce, predictions, weights, labels);
this.delta = delta;
tArguments.add(delta);
}
public HuberLoss(){ }
@Override

View File

@@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
@@ -36,6 +37,10 @@ public class L2Loss extends DynamicCustomOp {
super(sameDiff, new SDVariable[]{var});
}
public L2Loss(INDArray var){
super(new INDArray[]{var}, null);
}
@Override
public String opName() {
return "l2_loss";

View File

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -40,6 +41,13 @@ public class LogLoss extends BaseLoss {
addTArgument(epsilon);
}
public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){
super(lossReduce, predictions, weights, labels);
this.epsilon = epsilon;
addTArgument(epsilon);
}
public LogLoss(){ }
@Override

View File

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -43,6 +44,12 @@ public class LogPoissonLoss extends BaseLoss {
addArgs();
}
public LogPoissonLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, boolean full){
super(lossReduce, predictions, weights, labels);
this.full = full;
addArgs();
}
public LogPoissonLoss(){ }
protected void addArgs(){

View File

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -33,6 +34,10 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
super(lossReduce, predictions, weights, labels);
}
public MeanPairwiseSquaredErrorLoss(){ }
@Override

View File

@@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
@@ -35,6 +36,10 @@ public class MeanSquaredErrorLoss extends BaseLoss {
super(sameDiff, lossReduce, predictions, weights, labels);
}
public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
super(lossReduce, predictions, weights, labels);
}
public MeanSquaredErrorLoss(){ }
@Override

View File

@@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
@@ -54,6 +55,12 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
this(sameDiff, reductionMode, logits, weights, labels, 0.0);
}
public SigmoidCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){
super(lossReduce, predictions, weights, labels);
this.labelSmoothing = labelSmoothing;
addArgs();
}
public void addArgs() {
super.addArgs();
addTArgument(labelSmoothing);

View File

@@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
@@ -56,6 +57,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {
this(sameDiff, lossReduce, logits, weights, labels, 0.0);
}
public SoftmaxCrossEntropyLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double labelSmoothing){
super(lossReduce, predictions, weights, labels);
this.labelSmoothing = labelSmoothing;
addArgs();
}
public void addArgs() {
super.addArgs();

View File

@@ -17,6 +17,7 @@
package org.nd4j.linalg.api.ops.impl.loss;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@@ -24,6 +25,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
@@ -43,10 +45,13 @@ import java.util.*;
@NoArgsConstructor
public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {
public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull SameDiff sameDiff, @NonNull SDVariable logits, @NonNull SDVariable labels) {
super(null, sameDiff, new SDVariable[]{labels, logits}, false);
}
public SparseSoftmaxCrossEntropyLossWithLogits(@NonNull INDArray logits, @NonNull INDArray labels){
super(new INDArray[]{labels, logits}, null);
}
public void addArgs() {
}

View File

@@ -17,11 +17,13 @@
package org.nd4j.linalg.api.ops.impl.loss;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
@@ -43,6 +45,9 @@ public class WeightedCrossEntropyLoss extends DynamicCustomOp {
this.sameDiff = sameDiff;
}
public WeightedCrossEntropyLoss(@NonNull INDArray targets, @NonNull INDArray inputs, @NonNull INDArray weights){
super(new INDArray[] {targets, inputs, weights}, null);
}
@Override
public String opName() {

View File

@@ -16,10 +16,7 @@
package org.nd4j.linalg.factory;
import org.nd4j.linalg.factory.ops.NDBitwise;
import org.nd4j.linalg.factory.ops.NDMath;
import org.nd4j.linalg.factory.ops.NDNN;
import org.nd4j.linalg.factory.ops.NDRandom;
import org.nd4j.linalg.factory.ops.*;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import lombok.NonNull;
@@ -134,6 +131,11 @@ public class Nd4j {
*/
public static final NDNN nn = new NDNN();
/**
* Loss function namespace - operations related to loss functions.
*/
public static final NDLoss loss = new NDLoss();
/**
* Bitwise namespace - operations related to bitwise manipulation of arrays
*/
@@ -162,6 +164,11 @@
return nn;
}
/**
* Loss function namespace - operations related to loss functions.
*/
public static NDLoss loss(){ return loss; }
private final static String DATA_BUFFER_OPS = "databufferfactory";
private final static String CONVOLUTION_OPS = "convops";
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
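A minimal usage sketch (illustration only, not part of the diff) for the new static loss namespace accessor above, assuming the NDLoss class added below:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LossNamespaceSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.randn(DataType.DOUBLE, 10, 4);
        INDArray predictions = Nd4j.randn(DataType.DOUBLE, 10, 4);
        // Both the static field and the accessor method reach the same NDLoss instance;
        // weights may be null (treated as a scalar weight of 1.0).
        INDArray viaField = Nd4j.loss.absoluteDifference(labels, predictions, null);
        INDArray viaMethod = Nd4j.loss().absoluteDifference(labels, predictions, null);
        System.out.println(viaField + " " + viaMethod);
    }
}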

View File

@@ -0,0 +1,483 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;
public class NDLoss {
public NDLoss() {
}
/**
* Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @return output loss variable (NUMERIC type)
*/
public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce) {
NDValidation.validateNumerical("absoluteDifference", "label", label);
NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
NDValidation.validateNumerical("absoluteDifference", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, lossReduce))[0];
}
/**
* Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @return output loss variable (NUMERIC type)
*/
public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights) {
NDValidation.validateNumerical("absoluteDifference", "label", label);
NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
NDValidation.validateNumerical("absoluteDifference", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
}
/**
* Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is<br>
* equivalent to cosine distance when both the predictions and labels are normalized.<br>
* <b>Note</b>: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
* If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
* along the cosine distance dimension (with keepDims=true).<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param dimension Dimension to perform the cosine distance over
* @return output Cosine distance loss (NUMERIC type)
*/
public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce, int dimension) {
NDValidation.validateNumerical("cosineDistance", "label", label);
NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
NDValidation.validateNumerical("cosineDistance", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, lossReduce, dimension))[0];
}
/**
* Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is<br>
* equivalent to cosine distance when both the predictions and labels are normalized.<br>
* <b>Note</b>: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
* If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
* along the cosine distance dimension (with keepDims=true).<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param dimension Dimension to perform the cosine distance over
* @return output Cosine distance loss (NUMERIC type)
*/
public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
int dimension) {
NDValidation.validateNumerical("cosineDistance", "label", label);
NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
NDValidation.validateNumerical("cosineDistance", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
}
/**
* Hinge loss: a loss function used for training classifiers.<br>
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label value after internal conversion to {-1,1}<br>
* from the user-specified {0,1}. Note that labels should be provided with values {0,1}.<br>
*
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @return output Loss variable (NUMERIC type)
*/
public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce) {
NDValidation.validateNumerical("hingeLoss", "label", label);
NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
NDValidation.validateNumerical("hingeLoss", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, lossReduce))[0];
}
/**
* Hinge loss: a loss function used for training classifiers.<br>
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label value after internal conversion to {-1,1}<br>
* from the user-specified {0,1}. Note that labels should be provided with values {0,1}.<br>
*
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @return output Loss variable (NUMERIC type)
*/
public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights) {
NDValidation.validateNumerical("hingeLoss", "label", label);
NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
NDValidation.validateNumerical("hingeLoss", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
}
/**
* Huber loss function, used for robust regression. It is similar to both squared error loss and absolute difference loss,<br>
* though it is less sensitive to outliers than squared error.<br>
* Huber loss implements:<br>
* <pre><br>
* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}<br>
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}<br>
* </pre><br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param delta Loss function delta value
* @return output Huber loss (NUMERIC type)
*/
public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce, double delta) {
NDValidation.validateNumerical("huberLoss", "label", label);
NDValidation.validateNumerical("huberLoss", "predictions", predictions);
NDValidation.validateNumerical("huberLoss", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, lossReduce, delta))[0];
}
/**
* Huber loss function, used for robust regression. It is similar to both squared error loss and absolute difference loss,<br>
* though it is less sensitive to outliers than squared error.<br>
* Huber loss implements:<br>
* <pre><br>
* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}<br>
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}<br>
* </pre><br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param delta Loss function delta value
* @return output Huber loss (NUMERIC type)
*/
public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, double delta) {
NDValidation.validateNumerical("huberLoss", "label", label);
NDValidation.validateNumerical("huberLoss", "predictions", predictions);
NDValidation.validateNumerical("huberLoss", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta))[0];
}
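A plain-Java restatement of the piecewise Huber definition quoted in the Javadoc above (illustration only; it is not the op implementation and ignores weights and reduction):

public class HuberSketch {
    public static void main(String[] args) {
        double delta = 1.0;
        double label = 3.0, prediction = 0.5;
        double diff = Math.abs(label - prediction);       // 2.5, i.e. >= delta
        double loss = diff < delta
                ? 0.5 * diff * diff                       // quadratic inside the delta band
                : delta * diff - 0.5 * delta * delta;     // linear outside it -> 2.0
        System.out.println(loss);
    }
}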
/**
* L2 loss: 1/2 * sum(x^2)<br>
*
* @param var Variable to calculate L2 loss of (NUMERIC type)
* @return output L2 loss (NUMERIC type)
*/
public INDArray l2Loss(INDArray var) {
NDValidation.validateNumerical("l2Loss", "var", var);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.L2Loss(var))[0];
}
/**
* Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:<br>
* {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param epsilon epsilon
* @return output Log loss (NUMERIC type)
*/
public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce, double epsilon) {
NDValidation.validateNumerical("logLoss", "label", label);
NDValidation.validateNumerical("logLoss", "predictions", predictions);
NDValidation.validateNumerical("logLoss", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, lossReduce, epsilon))[0];
}
/**
* Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:<br>
* {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param epsilon epsilon
* @return output Log loss (NUMERIC type)
*/
public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, double epsilon) {
NDValidation.validateNumerical("logLoss", "label", label);
NDValidation.validateNumerical("logLoss", "predictions", predictions);
NDValidation.validateNumerical("logLoss", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, epsilon))[0];
}
/**
* Log Poisson loss: a loss function used for training classifiers.<br>
* Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
*
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
* @return output Loss variable (NUMERIC type)
*/
public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce, boolean full) {
NDValidation.validateNumerical("logPoisson", "label", label);
NDValidation.validateNumerical("logPoisson", "predictions", predictions);
NDValidation.validateNumerical("logPoisson", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, lossReduce, full))[0];
}
/**
* Log Poisson loss: a loss function used for training classifiers.<br>
* Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
*
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
* @return output Loss variable (NUMERIC type)
*/
public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, boolean full) {
NDValidation.validateNumerical("logPoisson", "label", label);
NDValidation.validateNumerical("logPoisson", "predictions", predictions);
NDValidation.validateNumerical("logPoisson", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full))[0];
}
/**
* Mean pairwise squared error.<br>
* MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.<br>
* For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:<br>
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @return output Loss variable, scalar output (NUMERIC type)
*/
public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce) {
NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
}
/**
* Mean pairwise squared error.<br>
* MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.<br>
* For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:<br>
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
* @return output Loss variable, scalar output (NUMERIC type)
*/
public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights) {
NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
}
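A plain-Java restatement of the MPWSE formula in the Javadoc above, for the three-element example (illustration only; weights and reduction are ignored):

public class MpwseSketch {
    public static void main(String[] args) {
        double[] p = {1.0, 2.0, 4.0};
        double[] l = {1.5, 2.0, 3.0};
        double sum = 0.0;
        int pairs = 0;
        for (int i = 0; i < p.length; i++) {
            for (int j = i + 1; j < p.length; j++) {
                double d = (p[i] - p[j]) - (l[i] - l[j]);
                sum += d * d; // ((pi - pj) - (li - lj))^2
                pairs++;
            }
        }
        System.out.println(sum / pairs); // average over the 3 pairs
    }
}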
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @return output Loss variable (NUMERIC type)
*/
public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
LossReduce lossReduce) {
NDValidation.validateNumerical("meanSquaredError", "label", label);
NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
NDValidation.validateNumerical("meanSquaredError", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
}
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @return output Loss variable (NUMERIC type)
*/
public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights) {
NDValidation.validateNumerical("meanSquaredError", "label", label);
NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
NDValidation.validateNumerical("meanSquaredError", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT))[0];
}
/**
* Sigmoid cross entropy: applies the sigmoid activation function to the input logits (input "pre-sigmoid predictions")<br>
* and implements the binary cross entropy loss function. This implementation is numerically more stable than using<br>
* standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.<br>
* Implements:<br>
* {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}<br>
* though this is done in a mathematically equivalent but more numerically stable form.<br>
* <br>
* When label smoothing is > 0, the following label smoothing is used:<br>
* <pre><br>
* {@code numClasses = labels.size(1);<br>
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}<br>
* </pre><br>
*
* @param label Label array (NUMERIC type)
* @param predictionLogits Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights,
LossReduce lossReduce, double labelSmoothing) {
NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, lossReduce, labelSmoothing))[0];
}
/**
* Sigmoid cross entropy: applies the sigmoid activation function to the input logits (input "pre-sigmoid predictions")<br>
* and implements the binary cross entropy loss function. This implementation is numerically more stable than using<br>
* standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.<br>
* Implements:<br>
* {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}<br>
* though this is done in a mathematically equivalent but more numerically stable form.<br>
* <br>
* When label smoothing is > 0, the following label smoothing is used:<br>
* <pre><br>
* {@code numClasses = labels.size(1);<br>
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}<br>
* </pre><br>
*
* @param label Label array (NUMERIC type)
* @param predictionLogits Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @return output Loss variable (NUMERIC type)
*/
public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights) {
NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
}
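A plain-Java restatement of the sigmoid label-smoothing rule quoted in the Javadoc above (illustration only):

public class LabelSmoothingSketch {
    public static void main(String[] args) {
        double labelSmoothing = 0.1;
        for (double label : new double[]{0.0, 1.0}) {
            // label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing
            double smoothed = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing;
            System.out.println(label + " -> " + smoothed); // 0.0 -> 0.05, 1.0 -> 0.95
        }
    }
}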
/**
* Applies the softmax activation function to the input, then implements multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, the returned shape is [numExamples] for [numExamples, numClasses] predictions/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>
* <pre><br>
* {@code numClasses = labels.size(1);<br>
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}<br>
* </pre><br>
*
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions,
INDArray weights, LossReduce lossReduce, double labelSmoothing) {
NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing))[0];
}
/**
* Applies the softmax activation function to the input, then implements multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, the returned shape is [numExamples] for [numExamples, numClasses] predictions/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>
* <pre><br>
* {@code numClasses = labels.size(1);<br>
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}<br>
* </pre><br>
*
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @return output Loss variable (NUMERIC type)
*/
public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions,
INDArray weights) {
NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0))[0];
}
/**
* As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable<br>
* is represented as an integer array instead of the equivalent one-hot array.<br>
* i.e., if logits are rank N, then labels have rank N-1<br>
*
* @param logits Logits array ("pre-softmax activations") (NUMERIC type)
* @param labels Labels array. Must be an integer type. (INT type)
* @return output Softmax cross entropy (NUMERIC type)
*/
public INDArray sparseSoftmaxCrossEntropy(INDArray logits, INDArray labels) {
NDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits);
NDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(logits, labels))[0];
}
/**
* Weighted cross entropy loss with logits<br>
*
* @param targets targets array (NUMERIC type)
* @param inputs input array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @return output Loss variable (NUMERIC type)
*/
public INDArray weightedCrossEntropyWithLogits(INDArray targets, INDArray inputs,
INDArray weights) {
NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets);
NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs);
NDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(targets, inputs, weights))[0];
}
}

View File

@@ -0,0 +1,453 @@
/* *****************************************************************************
* Copyright (c) 2019-2020 Konduit k.k.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.factory.ops;
import org.junit.Test;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class NDLossTest extends BaseNd4jTest {
public NDLossTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering(){
return 'c';
}
@Test
public void testAbsoluteDifference() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, w, reduction);
SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions, null, reduction);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, wArr, reduction);
INDArray y2 = Nd4j.loss().absoluteDifference(labelsArr, predictionsArr, null, reduction);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testCosineDistance() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
predictionsArr.diviColumnVector(predictionsArr.norm2(1));
labelsArr.diviColumnVector(labelsArr.norm2(1));
SDVariable loss = sd.loss().cosineDistance("loss", labels, predictions, w, reduction, 0);
SDVariable loss2 = sd.loss().cosineDistance("loss2", labels, predictions, null, reduction, 0);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, wArr, reduction, 0);
INDArray y2 = Nd4j.loss().cosineDistance(labelsArr, predictionsArr, null, reduction, 0);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testHingeLoss() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().hingeLoss("loss", labels, predictions, w, reduction);
SDVariable loss2 = sd.loss().hingeLoss("loss2", labels, predictions, null, reduction);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, wArr, reduction);
INDArray y2 = Nd4j.loss().hingeLoss(labelsArr, predictionsArr, null, reduction);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testHuberLoss() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().huberLoss("loss", labels, predictions, w, reduction, 0.02);
SDVariable loss2 = sd.loss().huberLoss("loss2", labels, predictions, null, reduction, 0.02);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().huberLoss(labelsArr, predictionsArr, wArr, reduction, 0.02);
INDArray y2 = Nd4j.loss().huberLoss(labelsArr, predictionsArr, null, reduction, 0.02);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testL2Loss() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().l2Loss("loss", predictions);
sd.associateArrayWithVariable(predictionsArr, predictions);
INDArray y_exp = loss.eval();
INDArray y = Nd4j.loss().l2Loss(predictionsArr);
assertEquals(y_exp, y);
}
@Test
public void testLogLoss() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelsArr, 0.5));
predictionsArr = Nd4j.rand(predictionsArr.shape()).muli(0.8).addi(0.1);
double eps = 1e-7;
SDVariable loss = sd.loss().logLoss("loss", labels, predictions, w, reduction, eps);
SDVariable loss2 = sd.loss().logLoss("loss2", labels, predictions, null, reduction, eps);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
//TODO: Test fails. "Op [log_loss] execution failed"
INDArray y = Nd4j.loss().logLoss(labelsArr, predictionsArr, wArr, reduction, eps);
INDArray y2 = Nd4j.loss().logLoss(labelsArr, predictionsArr, null, reduction, eps);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testLogPoisson() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction);
SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().logPoisson(labelsArr, predictionsArr, wArr, reduction, false);
INDArray y2 = Nd4j.loss().logPoisson(labelsArr, predictionsArr, null, reduction, false);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testMeanPairwiseSquaredError() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction);
SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, null, reduction);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, wArr, reduction);
INDArray y2 = Nd4j.loss().meanPairwiseSquaredError(labelsArr, predictionsArr, null, reduction);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testMeanSquaredError() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
SDVariable loss = sd.loss().meanSquaredError("loss", labels, predictions, w, reduction);
SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, null, reduction);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, wArr, reduction);
INDArray y2 = Nd4j.loss().meanSquaredError(labelsArr, predictionsArr, null, reduction);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testSigmoidCrossEntropy() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.create(new double[][]{
{0, 0, 0, 0}, {0, 0, 1, 1}, {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 1, 1},
{2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}, {2, 2, 2, 2}});
SDVariable w = sd.var("weights", wArr);
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
double labelSmoothing = 0.01;
SDVariable loss = sd.loss().sigmoidCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
INDArray y2 = Nd4j.loss().sigmoidCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testSoftmaxCrossEntropy() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
INDArray wArr = Nd4j.scalar(1.0); //TODO: This test fails with a complex weights array.
SDVariable w = null;
LossReduce reduction = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
labelsArr.assign(0);
for (int i = 0; i < labelsArr.size(0); i++) {
labelsArr.putScalar(i, i % labelsArr.size(1), 1.0);
}
double labelSmoothing = 0.0;
SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing);
SDVariable loss2 = sd.loss().softmaxCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y_exp2 = loss2.eval();
INDArray y = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, wArr, reduction, labelSmoothing);
INDArray y2 = Nd4j.loss().softmaxCrossEntropy(labelsArr, predictionsArr, null, reduction, labelSmoothing);
assertEquals(y_exp, y);
assertEquals(y_exp2, y2);
}
@Test
public void testSparseSoftmaxCrossEntropy() {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, minibatch, nOut);
SDVariable labels = sd.var("labels", DataType.INT32, -1);
INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut);
INDArray labelsArr = Nd4j.create(DataType.INT32, minibatch);
for( int i=0; i<minibatch; i++ ){
labelsArr.putScalar(i, i%nOut);
}
SDVariable loss = sd.loss().sparseSoftmaxCrossEntropy("loss", predictions, labels);
sd.associateArrayWithVariable(predictionsArr, predictions);
sd.associateArrayWithVariable(labelsArr, labels);
INDArray y_exp = loss.eval();
INDArray y = Nd4j.loss().sparseSoftmaxCrossEntropy(predictionsArr, labelsArr);
assertEquals(y_exp, y);
}
@Test
public void testWeightedCrossEntropyWithLogits() {
// This one from SamediffTests.java
SameDiff sameDiff = SameDiff.create();
INDArray targets = Nd4j.create(new long[]{1, 5});
INDArray inputs = Nd4j.create(new long[]{1, 5});
INDArray weights = Nd4j.create(new long[]{1, 5});
SDVariable sdInputs = sameDiff.var("inputs", inputs);
SDVariable sdWeights = sameDiff.var("weights", weights);
SDVariable sdTargets = sameDiff.var("targets", targets);
SDVariable res = sameDiff.loss().weightedCrossEntropyWithLogits(sdTargets, sdInputs, sdWeights);
INDArray resultArray = res.eval();
assertArrayEquals(new long[]{1, 5}, resultArray.shape());
// Make sure the INDArray interface produces the same result.
INDArray y = Nd4j.loss().weightedCrossEntropyWithLogits(targets, inputs, weights);
assertEquals(resultArray, y);
}
}