diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
index 56d53d07a..fa06ff8f7 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
@@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein(),
+ new LossSparseMCXENT()
};
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.IDENTITY, // MixtureDensity
Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
- Activation.IDENTITY // Wasserstein
+ Activation.IDENTITY, // Wasserstein
+ Activation.SOFTMAX, //sparse MCXENT
};
int[] nOut = new int[] {1, //xent
@@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // Mixture Density + tanh
10, // MultiLabel
2, // Wasserstein
+ 4, //sparse MCXENT
};
int[] minibatchSizes = new int[] {1, 3};
@@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
- new LossMultiLabel(), new LossWasserstein()
+ new LossMultiLabel(), new LossWasserstein(),
+ new LossSparseMCXENT()
};
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.IDENTITY, // MixtureDensity
Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel
- Activation.IDENTITY // Wasserstein
+ Activation.IDENTITY, // Wasserstein
+ Activation.SOFTMAX //sparse MCXENT
};
int[] nOut = new int[] {1, //xent
@@ -300,6 +305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // Mixture Density + tanh
10, // MultiLabel
2, // Wasserstein
+ 4 //sparse MCXENT
};
int[] minibatchSizes = new int[] {1, 3};
@@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
throw new UnsupportedOperationException();
}
+ break;
+ case "LossSparseMCXENT":
+ if (labelsShape.length == 2) {
+ ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1);
+ for (int i = 0; i < labelsShape[0]; i++) {
+ ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1]));
+ }
+ } else if (labelsShape.length == 3) {
+ ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]);
+ for (int i = 0; i < labelsShape[0]; i++) {
+ for (int j = 0; j < labelsShape[2]; j++) {
+ ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1]));
+ }
+ }
+ } else {
+ throw new UnsupportedOperationException();
+ }
break;
case "LossHinge":
case "LossSquaredHinge":
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
index 2552b6072..32a229101 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java
@@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import java.util.Random;
@@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
int nOut = 2;
int miniBatchSize = 3;
- ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()};
+ ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()};
for (int maskType = 0; maskType < 3; maskType++) {
Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
- INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
- for (int i = 0; i < miniBatchSize; i++) {
- for (int j = 0; j < timeSeriesLength; j++) {
- int idx = r.nextInt(nOut);
- labels.putScalar(new int[]{i, idx, j}, 1.0);
- }
- }
INDArray labelMask;
String mt;
@@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
break;
case 1:
//Per time step masking
- labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength);
+ labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
mt = "PerTimeStep";
break;
case 2:
//Per output masking:
- labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength});
+ labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
mt = "PerOutput";
break;
@@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
for (ILossFunction lf : lfs) {
+ INDArray labels;
+ if(lf instanceof LossSparseMCXENT){
+ labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength);
+ for (int i = 0; i < miniBatchSize; i++) {
+ for (int j = 0; j < timeSeriesLength; j++) {
+ int idx = r.nextInt(nOut);
+ labels.putScalar(new int[]{i, 0, j}, idx);
+ }
+ }
+ } else {
+ labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
+ for (int i = 0; i < miniBatchSize; i++) {
+ for (int j = 0; j < timeSeriesLength; j++) {
+ int idx = r.nextInt(nOut);
+ labels.putScalar(new int[]{i, idx, j}, 1.0);
+ }
+ }
+ }
+
+
Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX;
MultiLayerConfiguration conf =
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java
index 83f7f5f4c..1ff6c8dfe 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/LossFunctions.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -26,16 +27,27 @@ import org.nd4j.linalg.lossfunctions.impl.*;
public class LossFunctions {
/**
- * MSE: Mean Squared Error: Linear Regression
- * EXPLL: Exponential log likelihood: Poisson Regression
- * XENT: Cross Entropy: Binary Classification
- * MCXENT: Multiclass Cross Entropy
- * RMSE_XENT: RMSE Cross Entropy
- * SQUARED_LOSS: Squared Loss
- * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood
+ * MSE: Mean Squared Error: Linear Regression - {@link LossMSE}<br>
+ * L1: L1 loss (absolute value) - {@link LossL1}<br>
+ * XENT: Cross Entropy: Binary Classification - {@link LossBinaryXENT}<br>
+ * MCXENT: Multiclass Cross Entropy - {@link LossMCXENT}<br>
+ * SPARSE_MCXENT: Sparse multi-class cross entropy - {@link LossSparseMCXENT}<br>
+ * SQUARED_LOSS: Alias for mean squared error - {@link LossMSE}<br>
+ * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood - {@link LossNegativeLogLikelihood}<br>
+ * COSINE_PROXIMITY: Cosine proximity loss - {@link LossCosineProximity}<br>
+ * HINGE: Hinge loss - {@link LossHinge}<br>
+ * SQUARED_HINGE: Squared hinge loss - {@link LossSquaredHinge}<br>
+ * KL_DIVERGENCE: Kullback-Leibler divergence loss - {@link LossKLD}<br>
+ * MEAN_ABSOLUTE_ERROR: Mean absolute error loss - {@link LossMAE}<br>
+ * L2: L2 loss (sum of squared errors) - {@link LossL2}<br>
+ * MEAN_ABSOLUTE_PERCENTAGE_ERROR: MAPE loss - {@link LossMAPE}<br>
+ * MEAN_SQUARED_LOGARITHMIC_ERROR: MSLE loss - {@link LossMSLE}<br>
+ * POISSON: Poisson loss - {@link LossPoisson}<br>
+ * WASSERSTEIN: Wasserstein loss - {@link LossWasserstein}
*/
public enum LossFunction {
- MSE, L1, @Deprecated EXPLL, XENT, MCXENT, @Deprecated RMSE_XENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, @Deprecated CUSTOM, COSINE_PROXIMITY, HINGE, SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
+ MSE, L1, XENT, MCXENT, SPARSE_MCXENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, COSINE_PROXIMITY, HINGE,
+ SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
public ILossFunction getILossFunction() {
switch (this) {
@@ -48,6 +60,8 @@ public class LossFunctions {
return new LossBinaryXENT();
case MCXENT:
return new LossMCXENT();
+ case SPARSE_MCXENT:
+ return new LossSparseMCXENT();
case KL_DIVERGENCE:
case RECONSTRUCTION_CROSSENTROPY:
return new LossKLD();
@@ -68,7 +82,6 @@ public class LossFunctions {
case MEAN_SQUARED_LOGARITHMIC_ERROR:
return new LossMSLE();
case POISSON:
- case EXPLL:
return new LossPoisson();
case WASSERSTEIN:
return new LossWasserstein();
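The new enum constant resolves to the new loss implementation via getILossFunction(); a quick sketch of the mapping (the wrapper class name is hypothetical):

import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class SparseMcxentMappingSketch {
    public static void main(String[] args) {
        ILossFunction lf = LossFunction.SPARSE_MCXENT.getILossFunction();
        System.out.println(lf); // prints "LossSparseMCXENT()" when no weights are configured
    }
}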
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
index a985734ff..22bb27e0e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -40,10 +41,13 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
/**
*
* Multi-Class Cross Entropy loss function:<br>
- * L = sum_i actual_i * log( predicted_i )
+ * L = sum_i actual_i * log( predicted_i )<br>
+ * Note that labels are represented by a one-hot distribution<br>
+ * See {@link LossSparseMCXENT} for the equivalent but with labels as integers instead
*
* @author Alex Black, Susan Eraly
* @see LossNegativeLogLikelihood
+ * @see LossSparseMCXENT
*/
@EqualsAndHashCode
@JsonInclude(JsonInclude.Include.NON_NULL)
@@ -53,9 +57,9 @@ public class LossMCXENT implements ILossFunction {
@JsonSerialize(using = NDArrayTextSerializer.class)
@JsonDeserialize(using = NDArrayTextDeSerializer.class)
- private INDArray weights;
+ protected INDArray weights;
- private double softmaxClipEps;
+ protected double softmaxClipEps;
public LossMCXENT() {
this(null);
@@ -91,7 +95,7 @@ public class LossMCXENT implements ILossFunction {
this.softmaxClipEps = softmaxClipEps;
}
- private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+ protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
if(!labels.equalShapes(preOutput)){
Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java
index 162ac0797..a8453cf9c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossNegativeLogLikelihood.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
new file mode 100644
index 000000000..2ea0feb52
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
@@ -0,0 +1,133 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.lossfunctions.impl;
+
+
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.shape.OneHot;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.BooleanIndexing;
+import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.lossfunctions.LossUtil;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
+import org.nd4j.shade.jackson.annotation.JsonInclude;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
+import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
+
+/**
+ *
+ * Sparse Multi-Class Cross Entropy loss function:<br>
+ * L = sum_i actual_i * log( predicted_i )<br>
+ * Note: this is the same loss function as {@link LossMCXENT}, the only difference being the format of the labels -
+ * this loss function uses integer class indices (zero indexed) in the labels array, whereas LossMCXENT uses the equivalent
+ * one-hot representation
+ *
+ * @author Alex Black
+ * @see LossNegativeLogLikelihood
+ * @see LossMCXENT
+ */
+@EqualsAndHashCode(callSuper = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+@Getter @Setter
+public class LossSparseMCXENT extends LossMCXENT {
+ private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10;
+
+ public LossSparseMCXENT() {
+ this(null);
+ }
+
+ /**
+ * Sparse multi-class cross entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
+ * Note that the weights array must be a row vector, of length equal to the number of output classes.
+ * A weight vector of 1s should give identical results to no weight vector.
+ *
+ * @param weights Weights array (row vector). May be null.
+ */
+ public LossSparseMCXENT(INDArray weights) {
+ this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
+ }
+
+ /**
+ * Sparse multi-class cross entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
+ * Note that the weights array must be a row vector, of length equal to the number of output classes.
+ * A weight vector of 1s should give identical results to no weight vector.
+ * @param softmaxClipEps Epsilon for clipping the softmax output into [eps, 1-eps], to avoid numerical issues with log(0)
+ * @param weights Weights array (row vector). May be null.
+ */
+ public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) {
+ super(softmaxClipEps, weights);
+ }
+
+ protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+ INDArray oneHotLabels = toOneHot(labels, preOutput);
+ return super.scoreArray(oneHotLabels, preOutput, activationFn, mask);
+ }
+
+ @Override
+ public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
+ boolean average) {
+ INDArray oneHotLabels = toOneHot(labels, preOutput);
+ return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
+ }
+
+ @Override
+ public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+ INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask);
+ return scoreArr.sum(true,1).muli(-1);
+ }
+
+ @Override
+ public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+ INDArray oneHotLabels = toOneHot(labels, preOutput);
+ return super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
+ }
+
+ @Override
+ public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
+ INDArray mask, boolean average) {
+ INDArray oneHotLabels = toOneHot(labels, preOutput);
+ return new Pair<>(super.computeScore(oneHotLabels, preOutput, activationFn, mask, average),
+ super.computeGradient(oneHotLabels, preOutput, activationFn, mask));
+ }
+
+ private INDArray toOneHot(INDArray labels, INDArray preOutput){
+ Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
+ "with last dimension having size 1. Got labels array with shape %ndShape", labels);
+ INDArray oneHotLabels = preOutput.ulike();
+ Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
+ return oneHotLabels;
+ }
+
+
+ @Override
+ public String toString() {
+ if (weights == null)
+ return "LossSparseMCXENT()";
+ return "LossSparseMCXENT(weights=" + weights + ")";
+ }
+}
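A minimal usage sketch for the new class (shapes and values are illustrative; the class name is hypothetical): the labels array holds one zero-indexed class per example, with a trailing dimension of size 1, instead of the [minibatch, nOut] one-hot array that LossMCXENT expects.

import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

public class LossSparseMCXENTUsageSketch {
    public static void main(String[] args) {
        INDArray preOutput = Nd4j.rand(DataType.DOUBLE, 3, 4); // [minibatch=3, nOut=4] pre-softmax activations
        INDArray labels = Nd4j.createFromArray(new int[][]{{2}, {0}, {3}}); // shape [3, 1], integer class indices
        double score = new LossSparseMCXENT().computeScore(labels, preOutput, new ActivationSoftmax(), null, true);
        // Internally the labels are expanded to one-hot and delegated to LossMCXENT, so this should match
        // LossMCXENT scored on the equivalent one-hot labels
        System.out.println("average score = " + score);
    }
}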
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
index 091e5cad1..83145a048 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -40,6 +41,7 @@ import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.imports.listeners.ExecPrintListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -137,7 +139,7 @@ public class TFGraphTestAllHelper {
protected static void checkOnlyOutput(Map<String,INDArray> inputs, Map<String,INDArray> predictions, String modelName,
String baseDir, String modelFilename, ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
- Double maxRelErrorOverride, Double minAbsErrorOverride) throws IOException {
+ Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException {
Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
" must be null or both must be provided");
Nd4j.EPS_THRESHOLD = 1e-3;
@@ -152,7 +154,7 @@ public class TFGraphTestAllHelper {
outputsToCheck.add(s);
}
- Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck);
+ Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck, printArraysDebugging);
SameDiff graph = p.getFirst();
Map<String,INDArray> sameDiffPredictions = p.getSecond();
@@ -296,18 +298,18 @@ public class TFGraphTestAllHelper {
}
public static void checkIntermediate(Map<String,INDArray> inputs, String modelName, String baseDir, String modelFileName,
- ExecuteWith execType, File localTestDir) throws IOException {
- checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir);
+ ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException {
+ checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging);
}
public static void checkIntermediate(Map<String,INDArray> inputs, String modelName, String baseDir, String modelFileName,
ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
- Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir) throws IOException {
+ Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException {
Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
" must be null or both must be provided");
Nd4j.EPS_THRESHOLD = 1e-3;
OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order
- Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null);
+ Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging);
SameDiff graph = p.getFirst();
Map<String,INDArray> sdPredictions = p.getSecond();
@@ -388,13 +390,17 @@ public class TFGraphTestAllHelper {
public static Pair<SameDiff, Map<String,INDArray>> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map<String,INDArray> inputs,
ExecuteWith executeWith, BiFunction<File,String,SameDiff> graphLoaderFunction, List<Listener> listeners,
- Set<String> requiredOutputs) throws IOException {
+ Set<String> requiredOutputs, boolean printArraysDebugging) throws IOException {
log.info("\n\tRUNNING TEST " + modelName + "...");
SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName);
if(listeners != null){
graph.setListeners(listeners);
}
+ if(printArraysDebugging){
+ graph.addListeners(new ExecPrintListener());
+ }
+
if(requiredOutputs == null){
requiredOutputs = graph.variableMap().keySet();
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java
index df7f4726a..a9870ea82 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -182,7 +183,7 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
- TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+ TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH);
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
index db339b46e..a690bc5a8 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -194,7 +195,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
try {
- TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+ TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
} catch (Throwable t){
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
index 3a39dac37..b0b344f64 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -20,11 +21,9 @@ import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
-import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOpsHolder;
@@ -51,9 +50,12 @@ public class TFGraphTestList {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
+ //Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step
+ //Implemented internally using ExecPrintListener
+ public static final boolean printArraysDebugging = false;
+
public static String[] modelNames = new String[]{
-// "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
- "accumulate_n/rank0"
+ "resize_nearest_neighbor/int32"
};
@After
@@ -102,7 +104,7 @@ public class TFGraphTestList {
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith,
- TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+ TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging);
}
@Test @Ignore
@@ -110,7 +112,6 @@ public class TFGraphTestList {
//Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
File dir = testDir.newFolder();
Map<String,INDArray> inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir);
- TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir);
-
+ TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging);
}
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
index 05edef2b8..d08fb5148 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -274,7 +275,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
currentTestDir = testDir.newFolder();
log.info("----- SameDiff Exec: {} -----", modelName);
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF,
- LOADER, maxRE, minAbs);
+ LOADER, maxRE, minAbs, false);
if(ArrayUtils.contains(IGNORE_REGEXES_LIBND4J_EXEC_ONLY, modelName)){
log.warn("\n\tIGNORING MODEL FOR LIBND4J EXECUTION ONLY: ");
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java
new file mode 100644
index 000000000..6eff6b157
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listeners/ExecPrintListener.java
@@ -0,0 +1,29 @@
+package org.nd4j.imports.listeners;
+
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+
+/**
+ * A very quick and dirty debugging listener
+ * This listener just prints the outputs of any ops during execution
+ * @author Alex Black
+ */
+public class ExecPrintListener extends BaseListener {
+ @Override
+ public boolean isActive(Operation operation) {
+ return true;
+ }
+
+ @Override
+ public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+ System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
+ for(INDArray arr : outputs){
+ System.out.println(arr);
+ }
+ }
+}
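A usage sketch for the listener (the graph here is illustrative, the class name is hypothetical, and the output API may differ slightly between SameDiff versions):

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.listeners.ExecPrintListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class ExecPrintListenerSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.DOUBLE, 1, 3);
        SDVariable out = in.add("out", 1.0);

        sd.addListeners(new ExecPrintListener()); // prints every op's output arrays during execution

        sd.output(Collections.singletonMap("in", Nd4j.rand(DataType.DOUBLE, 1, 3)), "out");
    }
}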
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
index d01200db8..0e197f6cc 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/TestLossFunctionsSizeChecks.java
@@ -50,7 +50,7 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest {
@Test
public void testL2() {
- LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.EXPLL, LossFunction.XENT,
+ LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT,
LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY,
LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE,
LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR,