SDLoss cleanup. (#180)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
Robert Altena 2020-01-23 20:22:06 +09:00 committed by Alex Black
parent 256c9d20b0
commit 0caf50f80f
1 changed file with 20 additions and 63 deletions

@@ -1,4 +1,4 @@
-/*******************************************************************************
+/* *****************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
  *
  * This program and the accompanying materials are made available under the
@@ -34,11 +34,20 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.*;
  *
  * @author Alex Black
  */
+@SuppressWarnings("unused")
 public class SDLoss extends SDOps {
     public SDLoss(SameDiff sameDiff) {
         super(sameDiff);
     }
 
+    /**
+     * Helper to de-duplicate the default-weights logic: a null weights argument defaults to a scalar 1.0 constant of the predictions' data type.
+     */
+    private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions) {
+        String weightName = (name == null) ? null : name + "/weight";
+        return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights;
+    }
+
     /**
      * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
      */
@@ -60,12 +69,7 @@ public class SDLoss extends SDOps {
                                           SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("absolute difference loss", "predictions", predictions);
         validateNumerical("absolute difference loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -105,12 +109,7 @@ public class SDLoss extends SDOps {
                                      SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
         validateFloatingPoint("cosine distance loss", "predictions", predictions);
         validateNumerical("cosine distance loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -192,12 +191,7 @@ public class SDLoss extends SDOps {
                             SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
         validateFloatingPoint("huber loss", "predictions", predictions);
         validateNumerical("huber loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -258,12 +252,7 @@ public class SDLoss extends SDOps {
                           SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
         validateFloatingPoint("log loss", "predictions", predictions);
         validateNumerical("log loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -299,12 +288,7 @@ public class SDLoss extends SDOps {
                                  SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("log poisson loss", "predictions", predictions);
         validateNumerical("log poisson loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -341,12 +325,7 @@ public class SDLoss extends SDOps {
                                      SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
         validateNumerical("log poisson (full) loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -382,12 +361,7 @@ public class SDLoss extends SDOps {
     public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("mean pairwise squared error loss", "predictions", predictions);
         validateNumerical("mean pairwise squared error loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -417,12 +391,7 @@ public class SDLoss extends SDOps {
                                        SDVariable weights, @NonNull LossReduce lossReduce) {
         validateFloatingPoint("mean squared error loss", "predictions", predictions);
         validateNumerical("mean squared error loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictions);
         SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -468,12 +437,7 @@ public class SDLoss extends SDOps {
                                            SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
         validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
         validateNumerical("sigmoid cross entropy loss", "labels", label);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(predictionLogits.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, predictionLogits);
         SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -518,12 +482,7 @@ public class SDLoss extends SDOps {
                                            SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
         validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
         validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
-        if (weights == null) {
-            String weightName = null;
-            if(name != null)
-                weightName = name + "/weight";
-            weights = sd.constant(weightName, Nd4j.scalar(logitPredictions.dataType(), 1.0));
-        }
+        weights = getWeights(weights, name, logitPredictions);
         SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing);
         result = updateVariableNameAndReference(result, name);
         result.markAsLoss();
@@ -595,6 +554,4 @@ public class SDLoss extends SDOps {
         result.markAsLoss();
         return result;
     }
 }
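
Usage sketch (not part of the commit): the snippet below illustrates the behavior the refactor preserves, namely that a null weights argument is defaulted to a scalar 1.0 constant by the new getWeights helper, while explicit weights pass through unchanged. The sd.loss() accessor, the placeHolder signature, and the LossReduce constant are assumed from the beta-era SameDiff API; treat them as assumptions rather than part of this change.

// Minimal sketch, assuming the beta-era SameDiff API (sd.loss(), placeHolder,
// LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); not part of this commit.
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class SDLossSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 10);
        SDVariable predictions = sd.placeHolder("predictions", DataType.FLOAT, -1, 10);

        // Null weights: getWeights supplies a scalar 1.0 constant (named
        // "mse/weight" here), so every example is weighted equally.
        SDVariable defaultWeighted = sd.loss().meanSquaredError(
                "mse", labels, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);

        // Explicit weights: getWeights returns them unchanged.
        SDVariable weights = sd.constant("w", Nd4j.ones(DataType.FLOAT, 1, 10));
        SDVariable explicitlyWeighted = sd.loss().meanSquaredError(
                "mseWeighted", labels, predictions, weights, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);
    }
}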