parent
							
								
									256c9d20b0
								
							
						
					
					
						commit
						0caf50f80f
					
				@ -1,4 +1,4 @@
 | 
				
			|||||||
/*******************************************************************************
 | 
					/* *****************************************************************************
 | 
				
			||||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
					 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
				
			||||||
 *
 | 
					 *
 | 
				
			||||||
 * This program and the accompanying materials are made available under the
 | 
					 * This program and the accompanying materials are made available under the
 | 
				
			||||||
@ -34,11 +34,20 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.*;
 | 
				
			|||||||
 *
 | 
					 *
 | 
				
			||||||
 * @author Alex Black
 | 
					 * @author Alex Black
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					@SuppressWarnings("unused")
 | 
				
			||||||
public class SDLoss extends SDOps {
 | 
					public class SDLoss extends SDOps {
 | 
				
			||||||
    public SDLoss(SameDiff sameDiff) {
 | 
					    public SDLoss(SameDiff sameDiff) {
 | 
				
			||||||
        super(sameDiff);
 | 
					        super(sameDiff);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * helper to refactor duplicate code
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){
 | 
				
			||||||
 | 
					        String weightName = (name == null) ? null : name + "/weight";
 | 
				
			||||||
 | 
					        return (weights == null) ? null : sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
     * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
 | 
					     * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
 | 
				
			||||||
     */
 | 
					     */
 | 
				
			||||||
@ -60,12 +69,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                         SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
					                                         SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
				
			||||||
        validateFloatingPoint("absolute difference loss", "predictions", predictions);
 | 
					        validateFloatingPoint("absolute difference loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("absolute difference loss", "labels", label);
 | 
					        validateNumerical("absolute difference loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce);
 | 
					        SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -105,12 +109,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                     SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
 | 
					                                     SDVariable weights, @NonNull LossReduce lossReduce, int dimension) {
 | 
				
			||||||
        validateFloatingPoint("cosine distance loss", "predictions", predictions);
 | 
					        validateFloatingPoint("cosine distance loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("cosine distance loss", "labels", label);
 | 
					        validateNumerical("cosine distance loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension);
 | 
					        SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -192,12 +191,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
 | 
					                                SDVariable weights, @NonNull LossReduce lossReduce, double delta) {
 | 
				
			||||||
        validateFloatingPoint("huber loss", "predictions", predictions);
 | 
					        validateFloatingPoint("huber loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("huber loss", "labels", label);
 | 
					        validateNumerical("huber loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta);
 | 
					        SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -258,12 +252,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                              SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
 | 
					                              SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) {
 | 
				
			||||||
        validateFloatingPoint("log loss", "predictions", predictions);
 | 
					        validateFloatingPoint("log loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("log loss", "labels", label);
 | 
					        validateNumerical("log loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon);
 | 
					        SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -299,12 +288,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                 SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
					                                 SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
				
			||||||
        validateFloatingPoint("log poisson loss", "predictions", predictions);
 | 
					        validateFloatingPoint("log poisson loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("log poisson loss", "labels", label);
 | 
					        validateNumerical("log poisson loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce);
 | 
					        SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -341,12 +325,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                     SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
					                                     SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
				
			||||||
        validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
 | 
					        validateFloatingPoint("log poisson (full) loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("log poisson (full) loss", "labels", label);
 | 
					        validateNumerical("log poisson (full) loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce);
 | 
					        SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -382,12 +361,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
					    public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
				
			||||||
        validateFloatingPoint("main pairwise squared error loss", "predictions", predictions);
 | 
					        validateFloatingPoint("main pairwise squared error loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("mean pairwise squared error loss", "labels", label);
 | 
					        validateNumerical("mean pairwise squared error loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce);
 | 
					        SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -417,12 +391,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                       SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
					                                       SDVariable weights, @NonNull LossReduce lossReduce) {
 | 
				
			||||||
        validateFloatingPoint("mean squared error loss", "predictions", predictions);
 | 
					        validateFloatingPoint("mean squared error loss", "predictions", predictions);
 | 
				
			||||||
        validateNumerical("mean squared error loss", "labels", label);
 | 
					        validateNumerical("mean squared error loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce);
 | 
					        SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -468,12 +437,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                          SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
 | 
					                                          SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
 | 
				
			||||||
        validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
 | 
					        validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits);
 | 
				
			||||||
        validateNumerical("sigmoid cross entropy loss", "labels", label);
 | 
					        validateNumerical("sigmoid cross entropy loss", "labels", label);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, predictionLogits);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(predictionLogits.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing);
 | 
					        SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -518,12 +482,7 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
                                          SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
 | 
					                                          SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) {
 | 
				
			||||||
        validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
 | 
					        validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions);
 | 
				
			||||||
        validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
 | 
					        validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels);
 | 
				
			||||||
        if (weights == null) {
 | 
					        weights = getWeights(weights, name, logitPredictions);
 | 
				
			||||||
            String weightName = null;
 | 
					 | 
				
			||||||
            if(name != null)
 | 
					 | 
				
			||||||
                weightName = name + "/weight";
 | 
					 | 
				
			||||||
            weights = sd.constant(weightName, Nd4j.scalar(logitPredictions.dataType(), 1.0));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing);
 | 
					        SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing);
 | 
				
			||||||
        result = updateVariableNameAndReference(result, name);
 | 
					        result = updateVariableNameAndReference(result, name);
 | 
				
			||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
@ -595,6 +554,4 @@ public class SDLoss extends SDOps {
 | 
				
			|||||||
        result.markAsLoss();
 | 
					        result.markAsLoss();
 | 
				
			||||||
        return result;
 | 
					        return result;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user