From 4a2fedf3e7115a87b95e7ae30bc98916445333fa Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Fri, 22 Nov 2019 18:54:31 +1100
Subject: [PATCH] DL4J: Add Sparse multi-class cross entropy loss function
 (#72)

* #8432 Add sparse mcxent loss

Signed-off-by: AlexDBlack

* Fixes for LossSparseMCXENT

Signed-off-by: AlexDBlack

* add simple debugging listener for SameDiff exec debugging

Signed-off-by: AlexDBlack

* Extra gradient check + header polishing

Signed-off-by: AlexDBlack
---
 .../LossFunctionGradientCheck.java            |  29 +++-
 .../OutputLayerGradientChecks.java            |  34 +++--
 .../linalg/lossfunctions/LossFunctions.java   |  33 +++--
 .../linalg/lossfunctions/impl/LossMCXENT.java |  14 +-
 .../impl/LossNegativeLogLikelihood.java       |   3 +-
 .../lossfunctions/impl/LossSparseMCXENT.java  | 133 ++++++++++++++++++
 .../TFGraphs/TFGraphTestAllHelper.java        |  22 +--
 .../TFGraphs/TFGraphTestAllLibnd4j.java       |   5 +-
 .../TFGraphs/TFGraphTestAllSameDiff.java      |   5 +-
 .../imports/TFGraphs/TFGraphTestList.java     |  17 +--
 .../TFGraphs/TFGraphTestZooModels.java        |   5 +-
 .../imports/listeners/ExecPrintListener.java  |  29 ++++
 .../TestLossFunctionsSizeChecks.java          |   2 +-
 13 files changed, 279 insertions(+), 52 deletions(-)
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
 create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java

diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
index 56d53d07a..fa06ff8f7 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
@@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         new LossMultiLabel(), new LossWasserstein(),
+                        new LossSparseMCXENT()
         };

         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         Activation.IDENTITY, // MixtureDensity
                         Activation.TANH, // MixtureDensity + tanh
                         Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
-                        Activation.IDENTITY // Wasserstein
+                        Activation.IDENTITY, // Wasserstein
+                        Activation.SOFTMAX, //sparse MCXENT
         };

         int[] nOut = new int[] {1, //xent
@@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         10, // Mixture Density + tanh
                         10, // MultiLabel
                         2, // Wasserstein
+                        4, //sparse MCXENT
         };

         int[] minibatchSizes = new int[] {1, 3};
@@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
                         new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
-                        new LossMultiLabel(), new LossWasserstein()
+                        new LossMultiLabel(), new LossWasserstein(),
+                        new LossSparseMCXENT()
         };

         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         Activation.IDENTITY, // MixtureDensity
                         Activation.TANH, // MixtureDensity + tanh
                         Activation.TANH, // MultiLabel
-                        Activation.IDENTITY // Wasserstein
+                        Activation.IDENTITY, // Wasserstein
+                        Activation.SOFTMAX
         };

         int[] nOut = new int[] {1, //xent
@@ -300,6 +305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         10, // Mixture Density + tanh
                         10, // MultiLabel
                         2, // Wasserstein
+                        4
         };

         int[] minibatchSizes = new int[] {1, 3};
@@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                     throw new UnsupportedOperationException();
                 }
+                break;
+            case "LossSparseMCXENT":
+                if (labelsShape.length == 2) {
+                    ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1);
+                    for (int i = 0; i < labelsShape[0]; i++) {
+                        ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1]));
+                    }
+                } else if (labelsShape.length == 3) {
+                    ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]);
+                    for (int i = 0; i < labelsShape[0]; i++) {
+                        for (int j = 0; j < labelsShape[2]; j++) {
+                            ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1]));
+                        }
+                    }
+                } else {
+                    throw new UnsupportedOperationException();
+                }
                 break;
             case "LossHinge":
             case "LossSquaredHinge":

diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
index 2552b6072..32a229101 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
@@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
 import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

 import java.util.Random;

@@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
         int nOut = 2;
         int miniBatchSize = 3;

-        ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()};
+        ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()};

         for (int maskType = 0; maskType < 3; maskType++) {

             Random r = new Random(12345L);
             INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});

-            INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
-            for (int i = 0; i < miniBatchSize; i++) {
-                for (int j = 0; j < timeSeriesLength; j++) {
-                    int idx = r.nextInt(nOut);
-                    labels.putScalar(new int[]{i, idx, j}, 1.0);
-                }
-            }

             INDArray labelMask;
             String mt;
@@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
                     break;
                 case 1:
                     //Per time step masking
-                    labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength);
+                    labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength);
                     Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
                     mt = "PerTimeStep";
                     break;
                 case 2:
                     //Per output masking:
-                    labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength});
+                    labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength);
                     Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
                     mt = "PerOutput";
                     break;
@@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {

             for (ILossFunction lf : lfs) {
+                INDArray labels;
+                if(lf instanceof LossSparseMCXENT){
+                    labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength);
+                    for (int i = 0; i < miniBatchSize; i++) {
+                        for (int j = 0; j < timeSeriesLength; j++) {
+                            int idx = r.nextInt(nOut);
+                            labels.putScalar(new int[]{i, 0, j}, idx);
+                        }
+                    }
+                } else {
+                    labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
+                    for (int i = 0; i < miniBatchSize; i++) {
+                        for (int j = 0; j < timeSeriesLength; j++) {
+                            int idx = r.nextInt(nOut);
+                            labels.putScalar(new int[]{i, idx, j}, 1.0);
+                        }
+                    }
+                }
+

                 Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX;

                 MultiLayerConfiguration conf =
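A note on the two label encodings exercised above: for time series, LossMCXENT consumes one-hot labels of shape [miniBatchSize, nOut, timeSeriesLength], while LossSparseMCXENT consumes integer class indices of shape [miniBatchSize, 1, timeSeriesLength]. A minimal sketch of the correspondence (values are illustrative, not taken from the tests):

    // One-hot encoding (LossMCXENT): mark class 1 for example 0 at time step 2
    INDArray oneHotLabels = Nd4j.zeros(3, 2, 5);
    oneHotLabels.putScalar(new int[]{0, 1, 2}, 1.0);

    // Sparse encoding (LossSparseMCXENT): store the class index itself
    INDArray sparseLabels = Nd4j.zeros(3, 1, 5);
    sparseLabels.putScalar(new int[]{0, 0, 2}, 1);   // value 1 == class index 1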
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java
index 83f7f5f4c..1ff6c8dfe 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -26,16 +27,27 @@ import org.nd4j.linalg.lossfunctions.impl.*;
 public class LossFunctions {

     /**
-     * MSE: Mean Squared Error: Linear Regression<br>
-     * EXPLL: Exponential log likelihood: Poisson Regression<br>
-     * XENT: Cross Entropy: Binary Classification<br>
-     * MCXENT: Multiclass Cross Entropy<br>
-     * RMSE_XENT: RMSE Cross Entropy<br>
-     * SQUARED_LOSS: Squared Loss<br>
-     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood<br>
+     * MSE: Mean Squared Error: Linear Regression - {@link LossMSE}<br>
+     * L1: L1 loss (absolute value) - {@link LossL1}<br>
+     * XENT: Cross Entropy: Binary Classification - {@link LossBinaryXENT}<br>
+     * MCXENT: Multiclass Cross Entropy - {@link LossMCXENT}<br>
+     * SPARSE_MCXENT: Sparse multi-class cross entropy - {@link LossSparseMCXENT}<br>
+     * SQUARED_LOSS: Alias for mean squared error - {@link LossMSE}<br>
+     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood - {@link LossNegativeLogLikelihood}<br>
+     * COSINE_PROXIMITY: Cosine proximity loss - {@link LossCosineProximity}<br>
+     * HINGE: Hinge loss - {@link LossHinge}<br>
+     * SQUARED_HINGE: Squared hinge loss - {@link LossSquaredHinge}<br>
+     * KL_DIVERGENCE: Kullback-Leibler divergence loss - {@link LossKLD}<br>
+     * MEAN_ABSOLUTE_ERROR: Mean absolute error loss - {@link LossMAE}<br>
+     * L2: L2 loss (sum of squared errors) - {@link LossL2}<br>
+     * MEAN_ABSOLUTE_PERCENTAGE_ERROR: MAPE loss - {@link LossMAPE}<br>
+     * MEAN_SQUARED_LOGARITHMIC_ERROR: MSLE loss - {@link LossMSLE}<br>
+     * POISSON: Poisson loss - {@link LossPoisson}<br>
+     * WASSERSTEIN: Wasserstein loss - {@link LossWasserstein}
      */
     public enum LossFunction {
-        MSE, L1, @Deprecated EXPLL, XENT, MCXENT, @Deprecated RMSE_XENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, @Deprecated CUSTOM, COSINE_PROXIMITY, HINGE, SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
+        MSE, L1, XENT, MCXENT, SPARSE_MCXENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, COSINE_PROXIMITY, HINGE,
+        SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;

         public ILossFunction getILossFunction() {
             switch (this) {
@@ -48,6 +60,8 @@ public class LossFunctions {
                     return new LossBinaryXENT();
                 case MCXENT:
                     return new LossMCXENT();
+                case SPARSE_MCXENT:
+                    return new LossSparseMCXENT();
                 case KL_DIVERGENCE:
                 case RECONSTRUCTION_CROSSENTROPY:
                     return new LossKLD();
@@ -68,7 +82,6 @@ public class LossFunctions {
                 case MEAN_SQUARED_LOGARITHMIC_ERROR:
                     return new LossMSLE();
                 case POISSON:
-                case EXPLL:
                     return new LossPoisson();
                 case WASSERSTEIN:
                     return new LossWasserstein();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
index a985734ff..22bb27e0e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -40,10 +41,13 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
 /**
  *
  * Multi-Class Cross Entropy loss function:<br>
- * L = sum_i actual_i * log( predicted_i )
+ * L = -sum_i actual_i * log( predicted_i )<br>
+ * Note that labels are represented by a one-hot distribution<br>
+ * See {@link LossSparseMCXENT} for the equivalent loss with labels as integer indices instead
  *
  * @author Alex Black, Susan Eraly
  * @see LossNegativeLogLikelihood
+ * @see LossSparseMCXENT
  */
 @EqualsAndHashCode
 @JsonInclude(JsonInclude.Include.NON_NULL)
@@ -53,9 +57,9 @@ public class LossMCXENT implements ILossFunction {

     @JsonSerialize(using = NDArrayTextSerializer.class)
     @JsonDeserialize(using = NDArrayTextDeSerializer.class)
-    private INDArray weights;
+    protected INDArray weights;

-    private double softmaxClipEps;
+    protected double softmaxClipEps;

     public LossMCXENT() {
         this(null);
@@ -91,7 +95,7 @@ public class LossMCXENT implements ILossFunction {
         this.softmaxClipEps = softmaxClipEps;
     }

-    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+    protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
         if(!labels.equalShapes(preOutput)){
             Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
         }

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java
index 162ac0797..a8453cf9c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
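The new class that follows delegates all of its math to LossMCXENT after expanding the integer labels to a one-hot array via nd4j's OneHot op (the same call used in its toOneHot() method). A standalone sketch of that expansion, with illustrative shapes:

    // [minibatch, 1] integer class indices, and [minibatch, nOut] pre-activations
    INDArray labels = Nd4j.createFromArray(2, 0, 1).reshape(3, 1);
    INDArray preOutput = Nd4j.rand(DataType.FLOAT, 3, 4);

    // Expand to one-hot with the same shape/datatype as preOutput
    INDArray oneHot = preOutput.ulike();
    Nd4j.exec(new OneHot(labels.reshape(3), oneHot, 4));   // row i becomes one-hot for the index stored in labels row i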
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
new file mode 100644
index 000000000..2ea0feb52
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
@@ -0,0 +1,133 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.lossfunctions.impl;
+
+
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.shape.OneHot;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.BooleanIndexing;
+import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.lossfunctions.LossUtil;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
+import org.nd4j.shade.jackson.annotation.JsonInclude;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
+import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
+
+/**
+ * Sparse Multi-Class Cross Entropy loss function:<br>
+ * L = -sum_i actual_i * log( predicted_i )<br>
+ * Note: this is the same loss function as {@link LossMCXENT}, the only difference being the format of the labels -
+ * this loss function uses integer indices (zero indexed) for the labels array, whereas LossMCXENT uses the equivalent
+ * one-hot representation
+ *
+ * @author Alex Black
+ * @see LossNegativeLogLikelihood
+ * @see LossMCXENT
+ */
+@EqualsAndHashCode(callSuper = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+@Getter @Setter
+public class LossSparseMCXENT extends LossMCXENT {
+    private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10;
+
+    public LossSparseMCXENT() {
+        this(null);
+    }
+
+    /**
+     * Sparse multi-class cross entropy loss function where the output is (optionally) weighted/scaled by a fixed scalar value.
+     * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
+     * A weight vector of 1s should give identical results to no weight vector.
+     *
+     * @param weights Weights array (row vector). May be null.
+     */
+    public LossSparseMCXENT(INDArray weights) {
+        this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
+    }
+
+    /**
+     * Sparse multi-class cross entropy loss function where the output is (optionally) weighted/scaled by a fixed scalar value.
+     * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
+     * A weight vector of 1s should give identical results to no weight vector.
+     *
+     * @param softmaxClipEps Epsilon for clipping the softmax output, for numerical stability
+     * @param weights        Weights array (row vector). May be null.
+     */
+    public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) {
+        super(softmaxClipEps, weights);
+    }
+
+    protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.scoreArray(oneHotLabels, preOutput, activationFn, mask);
+    }
+
+    @Override
+    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
+                               boolean average) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
+    }
+
+    @Override
+    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask);
+        return scoreArr.sum(true,1).muli(-1);
+    }
+
+    @Override
+    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
+    }
+
+    @Override
+    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
+                                                          INDArray mask, boolean average) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return new Pair<>(super.computeScore(oneHotLabels, preOutput, activationFn, mask, average),
+                super.computeGradient(oneHotLabels, preOutput, activationFn, mask));
+    }
+
+    private INDArray toOneHot(INDArray labels, INDArray preOutput){
+        Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
+                "with last dimension having size 1. Got labels array with shape %ndShape", labels);
+        INDArray oneHotLabels = preOutput.ulike();
+        Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
+        return oneHotLabels;
+    }
+
+
+    @Override
+    public String toString() {
+        if (weights == null)
+            return "LossSparseMCXENT()";
+        return "LossSparseMCXENT(weights=" + weights + ")";
+    }
+}
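A quick way to sanity-check the class is to compare it against LossMCXENT on equivalent labels; the scores should agree to numerical precision. A sketch of such a check (not one of the patch's tests; shapes and values are illustrative):

    // 3 examples, 4 classes: the same targets in one-hot and integer-index form
    INDArray preOut = Nd4j.rand(DataType.DOUBLE, 3, 4);
    INDArray oneHot = Nd4j.zeros(DataType.DOUBLE, 3, 4);
    INDArray idx = Nd4j.create(DataType.INT, 3, 1);
    for (int i = 0; i < 3; i++) {
        int c = i % 4;                  // arbitrary class assignment
        oneHot.putScalar(i, c, 1.0);
        idx.putScalar(i, 0, c);
    }
    double dense = new LossMCXENT().computeScore(oneHot, preOut, new ActivationSoftmax(), null, true);
    double sparse = new LossSparseMCXENT().computeScore(idx, preOut, new ActivationSoftmax(), null, true);
    // expect: Math.abs(dense - sparse) < 1e-6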
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
index 091e5cad1..83145a048 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -40,6 +41,7 @@ import org.nd4j.autodiff.validation.TestCase;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.imports.listeners.ExecPrintListener;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -137,7 +139,7 @@ public class TFGraphTestAllHelper {
     protected static void checkOnlyOutput(Map<String, INDArray> inputs, Map<String, INDArray> predictions, String modelName,
                                           String baseDir, String modelFilename, ExecuteWith execType, BiFunction<File, String, SameDiff> loader,
-                                          Double maxRelErrorOverride, Double minAbsErrorOverride) throws IOException {
+                                          Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException {
         Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
                 " must be null or both must be provided");
         Nd4j.EPS_THRESHOLD = 1e-3;
@@ -152,7 +154,7 @@ public class TFGraphTestAllHelper {
             outputsToCheck.add(s);
         }

-        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck);
+        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck, printArraysDebugging);
         SameDiff graph = p.getFirst();
         Map<String,INDArray> sameDiffPredictions = p.getSecond();
@@ -296,18 +298,18 @@ public class TFGraphTestAllHelper {
     }

     public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
-                                         ExecuteWith execType, File localTestDir) throws IOException {
-        checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir);
+                                         ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException {
+        checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging);
     }

     public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
                                          ExecuteWith execType, BiFunction<File, String, SameDiff> loader,
-                                         Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir) throws IOException {
+                                         Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException {
         Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
                 " must be null or both must be provided");
         Nd4j.EPS_THRESHOLD = 1e-3;
         OpExecOrderListener listener = new OpExecOrderListener();       //Used to collect exec order
-        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null);
+        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging);
         SameDiff graph = p.getFirst();
         Map<String,INDArray> sdPredictions = p.getSecond();
@@ -388,13 +390,17 @@ public class TFGraphTestAllHelper {
     public static Pair<SameDiff, Map<String,INDArray>> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map<String, INDArray> inputs,
                                                                          ExecuteWith executeWith, BiFunction<File, String, SameDiff> graphLoaderFunction, List<Listener> listeners,
-                                                                         Set<String> requiredOutputs) throws IOException {
+                                                                         Set<String> requiredOutputs, boolean printArraysDebugging) throws IOException {
         log.info("\n\tRUNNING TEST " + modelName + "...");
         SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName);
         if(listeners != null){
             graph.setListeners(listeners);
         }

+        if(printArraysDebugging){
+            graph.addListeners(new ExecPrintListener());
+        }
+
         if(requiredOutputs == null){
             requiredOutputs = graph.variableMap().keySet();
         }

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java
index df7f4726a..a9870ea82 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -182,7 +183,7 @@ public class TFGraphTestAllLibnd4j {   //Note: Can't extend BaseNd4jTest here as
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
         //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH);
     }

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
index db339b46e..a690bc5a8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -194,7 +195,7 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

         try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
             //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
         } catch (Throwable t){
             log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
index 3a39dac37..b0b344f64 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -20,11 +21,9 @@ import org.junit.*;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
-import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.nativeblas.NativeOpsHolder;

@@ -51,9 +50,12 @@ public class TFGraphTestList {
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();

+    //Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step
+    //Implemented internally using ExecPrintListener
+    public static final boolean printArraysDebugging = false;
+
     public static String[] modelNames = new String[]{
-//            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
-            "accumulate_n/rank0"
+            "resize_nearest_neighbor/int32"
     };

     @After
@@ -102,7 +104,7 @@ public class TFGraphTestList {
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging);
     }

     @Test @Ignore
@@ -110,7 +112,6 @@ public class TFGraphTestList {
         //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
         File dir = testDir.newFolder();
         Map<String, INDArray> inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir);
-        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir);
-
+        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging);
     }
 }

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
index 05edef2b8..d08fb5148 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -274,7 +275,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
         currentTestDir = testDir.newFolder();
         log.info("----- SameDiff Exec: {} -----", modelName);
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF,
-                LOADER, maxRE, minAbs);
+                LOADER, maxRE, minAbs, false);

         if(ArrayUtils.contains(IGNORE_REGEXES_LIBND4J_EXEC_ONLY, modelName)){
             log.warn("\n\tIGNORING MODEL FOR LIBND4J EXECUTION ONLY: ");

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java
new file mode 100644
index 000000000..6eff6b157
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java
@@ -0,0 +1,29 @@
+package org.nd4j.imports.listeners;
+
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+
+/**
+ * A very quick and dirty debugging listener.
+ * This listener just prints the outputs of any ops during execution
+ * @author Alex Black
+ */
+public class ExecPrintListener extends BaseListener {
+    @Override
+    public boolean isActive(Operation operation) {
+        return true;
+    }
+
+    @Override
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+        System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
+        for(INDArray arr : outputs){
+            System.out.println(arr);
+        }
+    }
+}
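The listener is switched on in the TF graph test harness via the printArraysDebugging flag threaded through above, but it can also be attached to any SameDiff instance by hand. A minimal sketch (the graph itself is assumed to be built or imported elsewhere):

    SameDiff sd = SameDiff.create();            // assumed: variables/ops added or imported elsewhere
    sd.addListeners(new ExecPrintListener());   // every op execution now prints its output arrays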
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
index d01200db8..0e197f6cc 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
@@ -50,7 +50,7 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest {

     @Test
     public void testL2() {
-        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.EXPLL, LossFunction.XENT,
+        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT,
                 LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY,
                 LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE,
                 LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR,