From 0caf50f80f6ee411fa8f83d91bf45910e700e512 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Thu, 23 Jan 2020 20:22:06 +0900 Subject: [PATCH] SDLoss cleanup. (#180) Signed-off-by: Robert Altena --- .../nd4j/autodiff/samediff/ops/SDLoss.java | 83 +++++-------------- 1 file changed, 20 insertions(+), 63 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index 70da070b8..96d76c52d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -34,11 +34,20 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.*; * * @author Alex Black */ +@SuppressWarnings("unused") public class SDLoss extends SDOps { public SDLoss(SameDiff sameDiff) { super(sameDiff); } + /** + * helper to refactor duplicate code + */ + private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){ + String weightName = (name == null) ? null : name + "/weight"; + return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); + } + /** * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. */ @@ -60,12 +69,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("absolute difference loss", "predictions", predictions); validateNumerical("absolute difference loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -105,12 +109,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, int dimension) { validateFloatingPoint("cosine distance loss", "predictions", predictions); validateNumerical("cosine distance loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -192,12 +191,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double delta) { validateFloatingPoint("huber loss", "predictions", predictions); validateNumerical("huber loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -258,12 +252,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) { validateFloatingPoint("log loss", "predictions", predictions); validateNumerical("log loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -299,12 +288,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("log poisson loss", "predictions", predictions); validateNumerical("log poisson loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -341,12 +325,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("log poisson (full) loss", "predictions", predictions); validateNumerical("log poisson (full) loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -382,12 +361,7 @@ public class SDLoss extends SDOps { public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("main pairwise squared error loss", "predictions", predictions); validateNumerical("mean pairwise squared error loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -417,12 +391,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce) { validateFloatingPoint("mean squared error loss", "predictions", predictions); validateNumerical("mean squared error loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictions); SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -468,12 +437,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits); validateNumerical("sigmoid cross entropy loss", "labels", label); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(predictionLogits.dataType(), 1.0)); - } + weights = getWeights(weights, name, predictionLogits); SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -518,12 +482,7 @@ public class SDLoss extends SDOps { SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions); validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels); - if (weights == null) { - String weightName = null; - if(name != null) - weightName = name + "/weight"; - weights = sd.constant(weightName, Nd4j.scalar(logitPredictions.dataType(), 1.0)); - } + weights = getWeights(weights, name, logitPredictions); SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing); result = updateVariableNameAndReference(result, name); result.markAsLoss(); @@ -595,6 +554,4 @@ public class SDLoss extends SDOps { result.markAsLoss(); return result; } - - }