parent
256c9d20b0
commit
0caf50f80f
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -34,11 +34,20 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.*;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
public class SDLoss extends SDOps {
|
||||
public SDLoss(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* helper to refactor duplicate code
|
||||
*/
|
||||
private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
|
||||
String weightName = (name == null) ? null : name + "/weight";
|
||||
return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
|
||||
*/
|
||||
|
@ -60,12 +69,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce) {
|
||||
validateFloatingPoint("absolute difference loss", "predictions", predictions);
|
||||
validateNumerical("absolute difference loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -105,12 +109,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
|
||||
validateFloatingPoint("cosine distance loss", "predictions", predictions);
|
||||
validateNumerical("cosine distance loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -192,12 +191,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
|
||||
validateFloatingPoint("huber loss", "predictions", predictions);
|
||||
validateNumerical("huber loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -258,12 +252,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
|
||||
validateFloatingPoint("log loss", "predictions", predictions);
|
||||
validateNumerical("log loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -299,12 +288,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce) {
|
||||
validateFloatingPoint("log poisson loss", "predictions", predictions);
|
||||
validateNumerical("log poisson loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -341,12 +325,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce) {
|
||||
validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
|
||||
validateNumerical("log poisson (full) loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -382,12 +361,7 @@ public class SDLoss extends SDOps {
|
|||
public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
|
||||
validateFloatingPoint("main pairwise squared error loss", "predictions", predictions);
|
||||
validateNumerical("mean pairwise squared error loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -417,12 +391,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce) {
|
||||
validateFloatingPoint("mean squared error loss", "predictions", predictions);
|
||||
validateNumerical("mean squared error loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictions);
|
||||
SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -468,12 +437,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
|
||||
validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
|
||||
validateNumerical("sigmoid cross entropy loss", "labels", label);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(predictionLogits.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, predictionLogits);
|
||||
SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -518,12 +482,7 @@ public class SDLoss extends SDOps {
|
|||
SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
|
||||
validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
|
||||
validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
|
||||
if (weights == null) {
|
||||
String weightName = null;
|
||||
if(name != null)
|
||||
weightName = name + "/weight";
|
||||
weights = sd.constant(weightName, Nd4j.scalar(logitPredictions.dataType(), 1.0));
|
||||
}
|
||||
weights = getWeights(weights, name, logitPredictions);
|
||||
SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing);
|
||||
result = updateVariableNameAndReference(result, name);
|
||||
result.markAsLoss();
|
||||
|
@ -595,6 +554,4 @@ public class SDLoss extends SDOps {
|
|||
result.markAsLoss();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue