/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.gradientcheck;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.*;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.databind.ObjectMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;

/**
 * Gradient checks for the ILossFunction implementations: analytic gradients are compared against
 * numerical (finite-difference) gradients for a range of output activations, label shapes and minibatch sizes.
 *
 * Created by Alex on 12/09/2016.
 */
@Slf4j
public class LossFunctionGradientCheck extends BaseDL4JTest {
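
    // Gradient checks compare analytic gradients against finite-difference estimates, so the
    // default data type is set to double precision to keep numerical error small.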
    static {
        Nd4j.setDataType(DataType.DOUBLE);
    }
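
    // Finite-difference step size and the relative/absolute error thresholds passed to GradientCheckUtil below.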
    private static final boolean PRINT_RESULTS = true;
    private static final boolean RETURN_ON_FIRST_FAILURE = false;
    private static final double DEFAULT_EPS = 1e-6;
    private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
    private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;

    @Test
    public void lossFunctionGradientCheck() {
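        // lossFunctions, outputActivationFn and nOut are parallel arrays: entry i in each describes a single test case.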
        ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
                        new LossCosineProximity(), new LossHinge(), new LossKLD(), new LossKLD(), new LossL1(),
                        new LossL1(), new LossL1(), new LossL2(), new LossL2(), new LossMAE(), new LossMAE(),
                        new LossMAPE(), new LossMAPE(), new LossMCXENT(), new LossMSE(), new LossMSE(), new LossMSLE(),
                        new LossMSLE(), new LossNegativeLogLikelihood(), new LossNegativeLogLikelihood(),
                        new LossPoisson(), new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0),
                        new LossFMeasure(), new LossFMeasure(2.0),
                        LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                        LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                        new LossMultiLabel(), new LossWasserstein(),
        };

        Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
                        Activation.SIGMOID, //xent
                        Activation.TANH, //cosine
                        Activation.TANH, //hinge -> trying to predict 1 or -1
                        Activation.SIGMOID, //kld -> probab so should be between 0 and 1
                        Activation.SOFTMAX, //kld + softmax
                        Activation.TANH, //l1
                        Activation.RATIONALTANH, //l1
                        Activation.SOFTMAX, //l1 + softmax
                        Activation.TANH, //l2
                        Activation.SOFTMAX, //l2 + softmax
                        Activation.IDENTITY, //mae
                        Activation.SOFTMAX, //mae + softmax
                        Activation.IDENTITY, //mape
                        Activation.SOFTMAX, //mape + softmax
                        Activation.SOFTMAX, //mcxent
                        Activation.IDENTITY, //mse
                        Activation.SOFTMAX, //mse + softmax
                        Activation.SIGMOID, //msle - requires positive labels/activations due to log
                        Activation.SOFTMAX, //msle + softmax
                        Activation.SIGMOID, //nll
                        Activation.SOFTMAX, //nll + softmax
                        Activation.SIGMOID, //poisson - requires positive predictions due to log... not sure if this is the best option
                        Activation.TANH, //squared hinge
                        Activation.SIGMOID, //f-measure (binary, single sigmoid output)
                        Activation.SIGMOID, //f-measure (binary, single sigmoid output)
                        Activation.SOFTMAX, //f-measure (binary, 2-label softmax output)
                        Activation.SOFTMAX, //f-measure (binary, 2-label softmax output)
                        Activation.IDENTITY, // MixtureDensity
                        Activation.TANH, // MixtureDensity + tanh
                        Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
                        Activation.IDENTITY // Wasserstein
        };

        int[] nOut = new int[] {1, //xent
                        3, //xent
                        5, //cosine
                        3, //hinge
                        3, //kld
                        3, //kld + softmax
                        3, //l1
                        3, //l1
                        3, //l1 + softmax
                        3, //l2
                        3, //l2 + softmax
                        3, //mae
                        3, //mae + softmax
                        3, //mape
                        3, //mape + softmax
                        3, //mcxent
                        3, //mse
                        3, //mse + softmax
                        3, //msle
                        3, //msle + softmax
                        3, //nll
                        3, //nll + softmax
                        3, //poisson
                        3, //squared hinge
                        1, //f-measure (binary, single sigmoid output)
                        1, //f-measure (binary, single sigmoid output)
                        2, //f-measure (binary, 2-label softmax output)
                        2, //f-measure (binary, 2-label softmax output)
                        10, // Mixture Density
                        10, // Mixture Density + tanh
                        10, // MultiLabel
                        2, // Wasserstein
        };

        int[] minibatchSizes = new int[] {1, 3};


        List<String> passed = new ArrayList<>();
        List<String> failed = new ArrayList<>();

        for (int i = 0; i < lossFunctions.length; i++) {
            for (int j = 0; j < minibatchSizes.length; j++) {
                String testName = lossFunctions[i] + " - " + outputActivationFn[i] + " - minibatchSize = "
                                + minibatchSizes[j];

                Nd4j.getRandom().setSeed(12345);
                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                .dataType(DataType.DOUBLE)
                                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345)
                                .updater(new NoOp())
                                .dist(new UniformDistribution(-2, 2)).list()
                                .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                                .layer(1, new OutputLayer.Builder().lossFunction(lossFunctions[i])
                                                .activation(outputActivationFn[i]).nIn(4).nOut(nOut[i]).build())
                                .validateOutputLayerConfig(false)
                                .build();

                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();

                INDArray[] inOut = getFeaturesAndLabels(lossFunctions[i], minibatchSizes[j], 4, nOut[i], 12345);
                INDArray input = inOut[0];
                INDArray labels = inOut[1];

                log.info(" ***** Starting test: {} *****", testName);
//                System.out.println(Arrays.toString(labels.data().asDouble()));
//                System.out.println(Arrays.toString(net.output(input,false).data().asDouble()));
//                System.out.println(net.score(new DataSet(input,labels)));

                boolean gradOK;
                try {
                    gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                } catch (Exception e) {
                    e.printStackTrace();
                    failed.add(testName + "\t" + "EXCEPTION");
                    continue;
                }

                if (gradOK) {
                    passed.add(testName);
                } else {
                    failed.add(testName);
                }

                System.out.println("\n\n");
                TestUtils.testModelSerialization(net);
            }
        }


        System.out.println("---- Passed ----");
        for (String s : passed) {
            System.out.println(s);
        }

        System.out.println("---- Failed ----");
        for (String s : failed) {
            System.out.println(s);
        }

        assertEquals("Tests failed", 0, failed.size());
    }

    @Test
    public void lossFunctionGradientCheckLossLayer() {
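        // Same coverage as lossFunctionGradientCheck(), but the loss is applied via a (parameter-free) LossLayer
        // after a DenseLayer rather than an OutputLayer, and each loss function is first round-tripped through JSON.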
        ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
                        new LossCosineProximity(), new LossHinge(), new LossKLD(), new LossKLD(), new LossL1(),
                        new LossL1(), new LossL2(), new LossL2(), new LossMAE(), new LossMAE(), new LossMAPE(),
                        new LossMAPE(), new LossMCXENT(), new LossMSE(), new LossMSE(), new LossMSLE(), new LossMSLE(),
                        new LossNegativeLogLikelihood(), new LossNegativeLogLikelihood(), new LossPoisson(),
                        new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
                        new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                        LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                        new LossMultiLabel(), new LossWasserstein()
        };

        Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
                        Activation.SIGMOID, //xent
                        Activation.TANH, //cosine
                        Activation.TANH, //hinge -> trying to predict 1 or -1
                        Activation.SIGMOID, //kld -> probab so should be between 0 and 1
                        Activation.SOFTMAX, //kld + softmax
                        Activation.TANH, //l1
                        Activation.SOFTMAX, //l1 + softmax
                        Activation.TANH, //l2
                        Activation.SOFTMAX, //l2 + softmax
                        Activation.IDENTITY, //mae
                        Activation.SOFTMAX, //mae + softmax
                        Activation.IDENTITY, //mape
                        Activation.SOFTMAX, //mape + softmax
                        Activation.SOFTMAX, //mcxent
                        Activation.IDENTITY, //mse
                        Activation.SOFTMAX, //mse + softmax
                        Activation.SIGMOID, //msle - requires positive labels/activations due to log
                        Activation.SOFTMAX, //msle + softmax
                        Activation.SIGMOID, //nll
                        Activation.SOFTMAX, //nll + softmax
                        Activation.SIGMOID, //poisson - requires positive predictions due to log... not sure if this is the best option
                        Activation.TANH, //squared hinge
                        Activation.SIGMOID, //f-measure (binary, single sigmoid output)
                        Activation.SIGMOID, //f-measure (binary, single sigmoid output)
                        Activation.SOFTMAX, //f-measure (binary, 2-label softmax output)
                        Activation.SOFTMAX, //f-measure (binary, 2-label softmax output)
                        Activation.IDENTITY, // MixtureDensity
                        Activation.TANH, // MixtureDensity + tanh
                        Activation.TANH, // MultiLabel
                        Activation.IDENTITY // Wasserstein
        };

        int[] nOut = new int[] {1, //xent
                        3, //xent
                        5, //cosine
                        3, //hinge
                        3, //kld
                        3, //kld + softmax
                        3, //l1
                        3, //l1 + softmax
                        3, //l2
                        3, //l2 + softmax
                        3, //mae
                        3, //mae + softmax
                        3, //mape
                        3, //mape + softmax
                        3, //mcxent
                        3, //mse
                        3, //mse + softmax
                        3, //msle
                        3, //msle + softmax
                        3, //nll
                        3, //nll + softmax
                        3, //poisson
                        3, //squared hinge
                        1, //f-measure (binary, single sigmoid output)
                        1, //f-measure (binary, single sigmoid output)
                        2, //f-measure (binary, 2-label softmax output)
                        2, //f-measure (binary, 2-label softmax output)
                        10, // Mixture Density
                        10, // Mixture Density + tanh
                        10, // MultiLabel
                        2, // Wasserstein
        };

        int[] minibatchSizes = new int[] {1, 3};
//        int[] minibatchSizes = new int[]{3};


        List<String> passed = new ArrayList<>();
        List<String> failed = new ArrayList<>();

        for (int i = 0; i < lossFunctions.length; i++) {
            for (int j = 0; j < minibatchSizes.length; j++) {
                String testName = lossFunctions[i] + " - " + outputActivationFn[i] + " - minibatchSize = "
                                + minibatchSizes[j];

                // Serialize and de-serialize loss function
                // to ensure that we carry the parameters through
                // the serializer.
                try {
                    ObjectMapper m = NeuralNetConfiguration.mapper();
                    String s = m.writeValueAsString(lossFunctions[i]);
                    ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
                    lossFunctions[i] = lf2;
                } catch (IOException ex) {
                    ex.printStackTrace();
                    assertEquals("Tests failed: serialization of " + lossFunctions[i], 0, 1);
                }
                Nd4j.getRandom().setSeed(12345);
                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                .dataType(DataType.DOUBLE)
                                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345)
                                .updater(new NoOp())
                                .dist(new UniformDistribution(-2, 2)).list()
                                .layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH)
                                                .build())
                                .layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i])
                                                .activation(outputActivationFn[i]).build())
                                .validateOutputLayerConfig(false)
                                .build();

                MultiLayerNetwork net = new MultiLayerNetwork(conf);
                net.init();

                assertTrue(((LossLayer) net.getLayer(1).conf().getLayer()).getLossFn().getClass() == lossFunctions[i]
                                .getClass());

                INDArray[] inOut = getFeaturesAndLabels(lossFunctions[i], minibatchSizes[j], 4, nOut[i], 12345);
                INDArray input = inOut[0];
                INDArray labels = inOut[1];

                log.info(" ***** Starting test: {} *****", testName);
//                System.out.println(Arrays.toString(labels.data().asDouble()));
//                System.out.println(Arrays.toString(net.output(input,false).data().asDouble()));
//                System.out.println(net.score(new DataSet(input,labels)));

                boolean gradOK;
                try {
                    gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                } catch (Exception e) {
                    e.printStackTrace();
                    failed.add(testName + "\t" + "EXCEPTION");
                    continue;
                }

                if (gradOK) {
                    passed.add(testName);
                } else {
                    failed.add(testName);
                }

                System.out.println("\n\n");
                TestUtils.testModelSerialization(net);
            }
        }


        System.out.println("---- Passed ----");
        for (String s : passed) {
            System.out.println(s);
        }

        System.out.println("---- Failed ----");
        for (String s : failed) {
            System.out.println(s);
        }

        assertEquals("Tests failed", 0, failed.size());
    }

    @Test
    public void lossMultiLabelEdgeCases(){
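        // LossMultiLabel edge cases: label matrices that are entirely 1s or entirely 0s must still
        // produce a finite (non-NaN, non-infinite) score.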
        INDArray labels;
        Pair<Double, INDArray> gradientAndScore;

        final ActivationIdentity activationFn = new ActivationIdentity();
        final LossMultiLabel lossMultiLabel = new LossMultiLabel();
        final INDArray preOutput = Nd4j.rand(3, 3);

        // Base Case: Labels are NOT all 1 or 0
        labels = Nd4j.diag(Nd4j.ones(3));
        gradientAndScore = lossMultiLabel.computeGradientAndScore(labels, preOutput, activationFn, null, true);

        assertTrue(!gradientAndScore.getFirst().isNaN());
        assertTrue(!gradientAndScore.getFirst().isInfinite());

        // Edge Case: Labels are all 1
        labels = Nd4j.ones(3, 3);
        gradientAndScore = lossMultiLabel.computeGradientAndScore(labels, preOutput, activationFn, null, true);

        assertTrue(!gradientAndScore.getFirst().isNaN());
        assertTrue(!gradientAndScore.getFirst().isInfinite());

        // Edge Case: Labels are all 0
        labels = Nd4j.zeros(3, 3);
        gradientAndScore = lossMultiLabel.computeGradientAndScore(labels, preOutput, activationFn, null, true);

        assertTrue(!gradientAndScore.getFirst().isNaN());
        assertTrue(!gradientAndScore.getFirst().isInfinite());
    }
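
    /**
     * Generates random features, plus labels whose values are valid for the given loss function - e.g. one-hot
     * labels for MCXENT/NLL, +/-1 labels for the hinge losses, strictly positive labels for MAPE, probability
     * distributions for KLD, and binary {0,1} labels for XENT, Poisson and F-measure.
     */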
    public static INDArray[] getFeaturesAndLabels(ILossFunction l, long minibatch, long nIn, long nOut, long seed) {
        return getFeaturesAndLabels(l, new long[] {minibatch, nIn}, new long[] {minibatch, nOut}, seed);
    }

    public static INDArray[] getFeaturesAndLabels(ILossFunction l, int[] featuresShape, int[] labelsShape, long seed) {
        return getFeaturesAndLabels(l, ArrayUtil.toLongArray(featuresShape), ArrayUtil.toLongArray(labelsShape), seed);
    }

    public static INDArray[] getFeaturesAndLabels(ILossFunction l, long[] featuresShape, long[] labelsShape, long seed) {
        Nd4j.getRandom().setSeed(seed);
        Random r = new Random(seed);
        INDArray[] ret = new INDArray[2];

        ret[0] = Nd4j.rand(featuresShape);

        switch (l.getClass().getSimpleName()) {
            case "LossBinaryXENT":
                //Want binary vector labels
                ret[1] = Nd4j.rand(labelsShape);
                BooleanIndexing.replaceWhere(ret[1], 0, Conditions.lessThanOrEqual(0.5));
                BooleanIndexing.replaceWhere(ret[1], 1, Conditions.greaterThanOrEqual(0.5));
                break;
            case "LossCosineProximity":
                //Should be real-valued??
                ret[1] = Nd4j.rand(labelsShape).subi(0.5);
                break;
            case "LossKLD":
                //KL divergence: should be a probability distribution for labels??
                ret[1] = Nd4j.rand(labelsShape);
                if(labelsShape.length == 2){
                    Nd4j.getExecutioner().exec(new SoftMax(ret[1]));
                } else if(labelsShape.length == 3) {
                    for (int i = 0; i < labelsShape[2]; i++) {
                        Nd4j.getExecutioner().exec(new SoftMax(ret[1].get(all(), all(), point(i))));
                    }
                } else {
                    throw new RuntimeException();
                }
                break;
            case "LossMCXENT":
            case "LossNegativeLogLikelihood":
                ret[1] = Nd4j.zeros(labelsShape);
                if (labelsShape.length == 2) {
                    for (int i = 0; i < labelsShape[0]; i++) {
                        // FIXME: int cast
                        ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), 1.0);
                    }
                } else if (labelsShape.length == 3) {
                    for (int i = 0; i < labelsShape[0]; i++) {
                        for (int j = 0; j < labelsShape[2]; j++) {
                            // FIXME: int cast
                            ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, 1.0);
                        }
                    }
                } else {
                    throw new UnsupportedOperationException();
                }

                break;
            case "LossHinge":
            case "LossSquaredHinge":
                ret[1] = Nd4j.ones(labelsShape);
                if (labelsShape.length == 2) {
                    for (int i = 0; i < labelsShape[0]; i++) {
                        // FIXME: int cast
                        ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), -1.0);
                    }
                } else if (labelsShape.length == 3) {
                    for (int i = 0; i < labelsShape[0]; i++) {
                        for (int j = 0; j < labelsShape[2]; j++) {
                            // FIXME: int cast
                            ret[1].putScalar(i, r.nextInt((int) labelsShape[1]), j, -1.0);
                        }
                    }
                } else {
                    throw new UnsupportedOperationException();
                }
                break;
            case "LossMAPE":
                //requires non-zero values for actual...
                ret[1] = Nd4j.rand(labelsShape).addi(1.0); //1 to 2
                break;
            case "LossMAE":
            case "LossMSE":
            case "LossL1":
            case "LossL2":
                ret[1] = Nd4j.rand(labelsShape).muli(2).subi(1);
                break;
            case "LossMSLE":
                //Requires positive labels/activations due to log
                ret[1] = Nd4j.rand(labelsShape);
                break;
            case "LossPoisson":
                //Binary vector labels should be OK here??
                ret[1] = Nd4j.rand(labelsShape);
                BooleanIndexing.replaceWhere(ret[1], 0, Conditions.lessThanOrEqual(0.5));
                BooleanIndexing.replaceWhere(ret[1], 1, Conditions.greaterThanOrEqual(0.5));
                break;
            case "LossFMeasure":
                if (labelsShape[1] == 1) {
                    //single binary output case
                    ret[1] = Nd4j.getExecutioner()
                                    .exec(new BernoulliDistribution(Nd4j.createUninitialized(labelsShape), 0.5));
                    if (labelsShape[0] >= 2) {
                        //Ensure we have at least one "0" and one "1"
                        int count = ret[1].sumNumber().intValue();
                        if (count == 0) {
                            ret[1].putScalar(0, 0, 1.0);
                        } else if (count == ret[1].size(0)) {
                            ret[1].putScalar(0, 0, 0.0);
                        }
                    }
                } else {
                    //"softmax style" binary output case
                    ret[1] = Nd4j.create(labelsShape);
                    for (int i = 0; i < labelsShape[0]; i++) {
                        ret[1].putScalar(i, i % labelsShape[1], 1.0);
                    }
                }
                break;
            case "LossMixtureDensity":
                LossMixtureDensity lmd = (LossMixtureDensity) l;
                int labelWidth = lmd.getLabelWidth();
                ret[1] = Nd4j.rand(new long[] {labelsShape[0], labelWidth});
                break;
case "LossMultiLabel":
|
|
ret[1] = Nd4j.rand(labelsShape).lt(0.3).castTo(Nd4j.defaultFloatingPointType());
|
|
// ensure that there is no example that is all ones or all zeros
|
|
final INDArray sum = ret[1].sum(0);
|
|
for (int i = 0; i < labelsShape[0]; i++) {
|
|
final int rowSum = sum.getInt(i);
|
|
if (rowSum == 0) {
|
|
ret[1].putScalar(i, 0, 1);
|
|
} else if (rowSum == labelsShape[1]) {
|
|
ret[1].putScalar(i, 0, 0);
|
|
}
|
|
}
|
|
|
|
break;
            case "LossWasserstein":
                ret[1] = Nd4j.rand(labelsShape).mul(2).sub(1);
                break;

            default:
                throw new IllegalArgumentException("Unknown class: " + l.getClass().getSimpleName());
        }

        return ret;
    }

    @Test
    public void lossFunctionWeightedGradientCheck() {
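        // Repeats the gradient checks for the loss functions that accept a per-output weight vector,
        // using two different weight arrays.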
        Nd4j.getRandom().setSeed(12345);

        INDArray[] weights = new INDArray[] {Nd4j.create(new double[] {0.2, 0.3, 0.5}),
                        Nd4j.create(new double[] {1.0, 0.5, 2.0})};


        List<String> passed = new ArrayList<>();
        List<String> failed = new ArrayList<>();

        for (INDArray w : weights) {

            ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(w), new LossL1(w), new LossL1(w),
                            new LossL2(w), new LossL2(w), new LossMAE(w), new LossMAE(w), new LossMAPE(w),
                            new LossMAPE(w), new LossMCXENT(w), new LossMSE(w), new LossMSE(w), new LossMSLE(w),
                            new LossMSLE(w), new LossNegativeLogLikelihood(w), new LossNegativeLogLikelihood(w),};

            Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
                            Activation.TANH, //l1
                            Activation.SOFTMAX, //l1 + softmax
                            Activation.TANH, //l2
                            Activation.SOFTMAX, //l2 + softmax
                            Activation.IDENTITY, //mae
                            Activation.SOFTMAX, //mae + softmax
                            Activation.IDENTITY, //mape
                            Activation.SOFTMAX, //mape + softmax
                            Activation.SOFTMAX, //mcxent
                            Activation.IDENTITY, //mse
                            Activation.SOFTMAX, //mse + softmax
                            Activation.SIGMOID, //msle - requires positive labels/activations due to log
                            Activation.SOFTMAX, //msle + softmax
                            Activation.SIGMOID, //nll
                            Activation.SOFTMAX, //nll + softmax
            };

            int[] minibatchSizes = new int[] {1, 3};

            for (int i = 0; i < lossFunctions.length; i++) {
                for (int j = 0; j < minibatchSizes.length; j++) {
                    String testName = lossFunctions[i] + " - " + outputActivationFn[i] + " - minibatchSize = "
                                    + minibatchSizes[j] + "; weights = " + w;

                    Nd4j.getRandom().setSeed(12345);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                                    .dataType(DataType.DOUBLE)
                                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345)
                                    .updater(new NoOp())
                                    //.dist(new UniformDistribution(-3, 3))
                                    .dist(new NormalDistribution(0, 1))
                                    .list()
                                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH)
                                                    .build())
                                    .layer(1, new OutputLayer.Builder().lossFunction(lossFunctions[i])
                                                    .activation(outputActivationFn[i]).nIn(4).nOut(3).build())
                                    .validateOutputLayerConfig(false)
                                    .build();

                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();

                    //Re-sample any parameters with magnitude below 0.01 or above 1.5, to avoid gradient check flakiness from very small or very large weights
                    INDArray params = net.params();
                    for( int x=0; x<params.length(); x++ ){
                        while(Math.abs(params.getDouble(x)) < 0.01 || Math.abs(params.getDouble(x)) > 1.5){
                            double d = Nd4j.getRandom().nextDouble();
                            params.putScalar(x, -1.5 + d * 3);
                        }
                    }

                    INDArray[] inOut = getFeaturesAndLabels(lossFunctions[i], minibatchSizes[j], 4, 3, 12345);
                    INDArray input = inOut[0];
                    INDArray labels = inOut[1];

                    log.info(" ***** Starting test: {} *****", testName);
//                    System.out.println(Arrays.toString(labels.data().asDouble()));
//                    System.out.println(Arrays.toString(net.output(input,false).data().asDouble()));
//                    System.out.println(net.score(new DataSet(input,labels)));

                    boolean gradOK;
                    try {
                        gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                                        DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels);
                    } catch (Exception e) {
                        e.printStackTrace();
                        failed.add(testName + "\t" + "EXCEPTION");
                        continue;
                    }

                    if (gradOK) {
                        passed.add(testName);
                    } else {
                        failed.add(testName);
                    }

                    System.out.println("\n\n");
                }
            }
        }

        System.out.println("---- Passed ----");
        for (String s : passed) {
            System.out.println(s);
        }

        System.out.println("---- Failed ----");
        for (String s : failed) {
            System.out.println(s);
        }

        assertEquals("Tests failed", 0, failed.size());
    }
}