parent 256c9d20b0
commit 0caf50f80f

@@ -1,4 +1,4 @@
-/*******************************************************************************
+/* *****************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
  *
  * This program and the accompanying materials are made available under the
@@ -34,11 +34,20 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.*;
  *
  * @author Alex Black
  */
+@SuppressWarnings("unused")
 public class SDLoss extends SDOps {
     public SDLoss(SameDiff sameDiff) {
         super(sameDiff);
     }
 
+    /**
+     * Helper to avoid duplicating the default-weights logic in each loss method.
+     */
+    private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
+        String weightName = (name == null) ? null : name + "/weight";
+        return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights;
+    }
+
     /**
      * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
      */
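Editor's note (sketch, not part of the commit): the getWeights helper added above is a single-expression equivalent of the block that each loss method previously inlined, so the call sites in the hunks below can drop their if (weights == null) guard entirely. Expanded into long form, its behaviour is:

    // Long-form equivalent of the new helper (illustrative only).
    // Caller-supplied weights pass through unchanged; when weights is null a
    // scalar constant weight of 1.0 is created with the predictions' data type,
    // named "<name>/weight" when a base name was given.
    private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions) {
        if (weights != null) {
            return weights;
        }
        String weightName = (name == null) ? null : name + "/weight";
        return sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
    }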
@@ -60,12 +69,7 @@ public class SDLoss extends SDOps {
             SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("absolute difference loss", "predictions", predictions);
         validateNumerical("absolute difference loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -105,12 +109,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
         validateFloatingPoint("cosine distance loss", "predictions", predictions);
         validateNumerical("cosine distance loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -192,12 +191,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
         validateFloatingPoint("huber loss", "predictions", predictions);
         validateNumerical("huber loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -258,12 +252,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
         validateFloatingPoint("log loss", "predictions", predictions);
         validateNumerical("log loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -299,12 +288,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("log poisson loss", "predictions", predictions);
         validateNumerical("log poisson loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -341,12 +325,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
         validateNumerical("log poisson (full) loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -382,12 +361,7 @@
     public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("main pairwise squared error loss", "predictions", predictions);
         validateNumerical("mean pairwise squared error loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -417,12 +391,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("mean squared error loss", "predictions", predictions);
         validateNumerical("mean squared error loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -468,12 +437,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
         validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
         validateNumerical("sigmoid cross entropy loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictionLogits.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictionLogits);
         SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -518,12 +482,7 @@
             SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
         validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
         validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(logitPredictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, logitPredictions);
         SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -595,6 +554,4 @@ public class SDLoss extends SDOps {
         result.markAsLoss();
         return result;
     }
-
-
 }
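For context, a minimal usage sketch of the refactored loss API. This is an assumption about typical SameDiff usage rather than code from this commit: the sd.loss() namespace accessor, placeholder names, and shapes are illustrative.

    import org.nd4j.autodiff.loss.LossReduce;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class SDLossUsageSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable predictions = sd.placeHolder("predictions", DataType.FLOAT, -1, 10);
            SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 10);

            // Passing null weights exercises the new getWeights(...) default path:
            // a constant scalar weight of 1.0 named "absDiff/weight" is created.
            SDVariable loss = sd.loss().absoluteDifference("absDiff", labels, predictions,
                    null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
            System.out.println(loss);
        }
    }

Passing an explicit SDVariable for weights instead of null leaves it untouched, matching the behaviour of the removed per-method blocks.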