DL4J: Add Sparse multi-class cross entropy loss function (#72)

* #8432 Add sparse mcxent loss

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes for LossSparseMCXENT

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* add simple debugging listener for SameDiff exec debugging

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Extra gradient check + header polishing

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Alex Black, 2019-11-22 18:54:31 +11:00, committed by GitHub
parent 823bd0ff88
commit 4a2fedf3e7
13 changed files with 279 additions and 52 deletions
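
Before the file-by-file changes, here is a minimal usage sketch (not part of the commit itself) showing how the new loss function plugs into a standard DL4J configuration; the nIn/nOut values and wrapper class name are illustrative assumptions:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

public class SparseMcxentUsage {
    public static void main(String[] args) {
        // Output layer with 4 classes. With LossSparseMCXENT, labels are integer
        // class indices of shape [minibatch, 1] (or [minibatch, 1, timeSeriesLength]
        // for RNN output layers) rather than one-hot arrays of shape [minibatch, 4].
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new OutputLayer.Builder()
                        .nIn(10).nOut(4)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(new LossSparseMCXENT())
                        .build())
                .build();
    }
}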


@@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein(),
new LossSparseMCXENT()
};
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.IDENTITY, // MixtureDensity
Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
Activation.IDENTITY // Wasserstein
Activation.IDENTITY, // Wasserstein
Activation.SOFTMAX, //sparse MCXENT
};
int[] nOut = new int[] {1, //xent
@@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // Mixture Density + tanh
10, // MultiLabel
2, // Wasserstein
4, //sparse MCXENT
};
int[] minibatchSizes = new int[] {1, 3};
@@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein()
new LossMultiLabel(), new LossWasserstein(),
new LossSparseMCXENT()
};
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.IDENTITY, // MixtureDensity
Activation.TANH, // MixtureDensity + tanh
Activation.TANH, // MultiLabel
Activation.IDENTITY // Wasserstein
Activation.IDENTITY, // Wasserstein
Activation.SOFTMAX
};
int[] nOut = new int[] {1, //xent
@@ -300,6 +305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // Mixture Density + tanh
10, // MultiLabel
2, // Wasserstein
4
};
int[] minibatchSizes = new int[] {1, 3};
@@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
throw new UnsupportedOperationException();
}
break;
case "LossSparseMCXENT":
if (labelsShape.length == 2) {
ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1);
for (int i = 0; i < labelsShape[0]; i++) {
ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1]));
}
} else if (labelsShape.length == 3) {
ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]);
for (int i = 0; i < labelsShape[0]; i++) {
for (int j = 0; j < labelsShape[2]; j++) {
ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1]));
}
}
} else {
throw new UnsupportedOperationException();
}
break;
case "LossHinge":
case "LossSquaredHinge":


@@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import java.util.Random;
@@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
int nOut = 2;
int miniBatchSize = 3;
ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()};
ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()};
for (int maskType = 0; maskType < 3; maskType++) {
Random r = new Random(12345L);
INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels.putScalar(new int[]{i, idx, j}, 1.0);
}
}
INDArray labelMask;
String mt;
@@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
break;
case 1:
//Per time step masking
labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength);
labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
mt = "PerTimeStep";
break;
case 2:
//Per output masking:
labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength});
labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength);
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
mt = "PerOutput";
break;
@@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
for (ILossFunction lf : lfs) {
INDArray labels;
if(lf instanceof LossSparseMCXENT){
labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels.putScalar(new int[]{i, 0, j}, idx);
}
}
} else {
labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
for (int i = 0; i < miniBatchSize; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int idx = r.nextInt(nOut);
labels.putScalar(new int[]{i, idx, j}, 1.0);
}
}
}
Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX;
MultiLayerConfiguration conf =


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -26,16 +27,27 @@ import org.nd4j.linalg.lossfunctions.impl.*;
public class LossFunctions {
/**
* MSE: Mean Squared Error: Linear Regression<br>
* EXPLL: Exponential log likelihood: Poisson Regression<br>
* XENT: Cross Entropy: Binary Classification<br>
* MCXENT: Multiclass Cross Entropy<br>
* RMSE_XENT: RMSE Cross Entropy<br>
* SQUARED_LOSS: Squared Loss<br>
* NEGATIVELOGLIKELIHOOD: Negative Log Likelihood<br>
* MSE: Mean Squared Error: Linear Regression - {@link LossMSE}<br>
* l1: L1 loss (absolute value) - {@link LossL1}<br>
* XENT: Cross Entropy: Binary Classification - {@link LossBinaryXENT}<br>
* MCXENT: Multiclass Cross Entropy - {@link LossMCXENT}<br>
* SPARSE_MCXENT: Sparse multi-class cross entropy - {@link LossSparseMCXENT}<br>
* SQUARED_LOSS: Alias for mean squared error - {@link LossMSE}<br>
* NEGATIVELOGLIKELIHOOD: Negative Log Likelihood - {@link LossNegativeLogLikelihood}<br>
* COSINE_PROXIMITY: Cosine proximity loss - {@link LossCosineProximity}<br>
* HINGE: Hinge loss - {@link LossHinge}<br>
* SQUARED_HINGE: Squared hinge loss - {@link LossSquaredHinge}<br>
* KL_DIVERGENCE: Kullback-Leibler divergence loss - {@link LossKLD}<br>
* MEAN_ABSOLUTE_ERROR: mean absolute error loss - {@link LossMAE}<br>
* L2: L2 loss (sum of squared errors) - {@link LossL2}<br>
* MEAN_ABSOLUTE_PERCENTAGE_ERROR: MAPE loss - {@link LossMAPE}<br>
* MEAN_SQUARED_LOGARITHMIC_ERROR: MSLE loss - {@link LossMSLE}<br>
* POISSON: Poisson loss - {@link LossPoisson}<br>
* WASSERSTEIN: Wasserstein loss - {@link LossWasserstein}
*/
public enum LossFunction {
MSE, L1, @Deprecated EXPLL, XENT, MCXENT, @Deprecated RMSE_XENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, @Deprecated CUSTOM, COSINE_PROXIMITY, HINGE, SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
MSE, L1, XENT, MCXENT, SPARSE_MCXENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, COSINE_PROXIMITY, HINGE,
SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
public ILossFunction getILossFunction() {
switch (this) {
@@ -48,6 +60,8 @@ public class LossFunctions {
return new LossBinaryXENT();
case MCXENT:
return new LossMCXENT();
case SPARSE_MCXENT:
return new LossSparseMCXENT();
case KL_DIVERGENCE:
case RECONSTRUCTION_CROSSENTROPY:
return new LossKLD();
@@ -68,7 +82,6 @@ public class LossFunctions {
case MEAN_SQUARED_LOGARITHMIC_ERROR:
return new LossMSLE();
case POISSON:
case EXPLL:
return new LossPoisson();
case WASSERSTEIN:
return new LossWasserstein();
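
As a quick illustration of the enum wiring above (a sketch, not part of the commit), the new SPARSE_MCXENT value resolves to the new loss implementation:

import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

public class EnumMappingCheck {
    public static void main(String[] args) {
        // The new enum value maps to the new implementation via getILossFunction()
        ILossFunction lf = LossFunction.SPARSE_MCXENT.getILossFunction();
        System.out.println(lf instanceof LossSparseMCXENT); // prints: true
    }
}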


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -40,10 +41,13 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
/**
*
* Multi-Class Cross Entropy loss function:<br>
* L = sum_i actual_i * log( predicted_i )
* L = sum_i actual_i * log( predicted_i )<br>
* Note that labels are represented by a one-hot distribution<br>
* See {@link LossSparseMCXENT} for the equivalent but with labels as integers instead
*
* @author Alex Black, Susan Eraly
* @see LossNegativeLogLikelihood
* @see LossSparseMCXENT
*/
@EqualsAndHashCode
@JsonInclude(JsonInclude.Include.NON_NULL)
@@ -53,9 +57,9 @@ public class LossMCXENT implements ILossFunction {
@JsonSerialize(using = NDArrayTextSerializer.class)
@JsonDeserialize(using = NDArrayTextDeSerializer.class)
private INDArray weights;
protected INDArray weights;
private double softmaxClipEps;
protected double softmaxClipEps;
public LossMCXENT() {
this(null);
@@ -91,7 +95,7 @@ public class LossMCXENT implements ILossFunction {
this.softmaxClipEps = softmaxClipEps;
}
private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
if(!labels.equalShapes(preOutput)){
Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
}


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at


@@ -0,0 +1,133 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.lossfunctions.impl;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
/**
*
* Sparse Multi-Class Cross Entropy loss function:<br>
* L = sum_i actual_i * log( predicted_i )<br>
* Note: this is the same loss function as {@link LossMCXENT}, the only difference being the format for the labels -
this loss function uses integer indices (zero indexed) for the labels array, whereas LossMCXENT uses the equivalent
* one-hot representation
*
* @author Alex Black
* @see LossNegativeLogLikelihood
* @see LossMCXENT
*/
@EqualsAndHashCode(callSuper = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Getter @Setter
public class LossSparseMCXENT extends LossMCXENT {
private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10;
public LossSparseMCXENT() {
this(null);
}
/**
* Multi-Class Cross Entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
* Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
* A weight vector of 1s should give identical results to no weight vector.
*
* @param weights Weights array (row vector). May be null.
*/
public LossSparseMCXENT(INDArray weights) {
this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
}
/**
* Multi-Class Cross Entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
* Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
* A weight vector of 1s should give identical results to no weight vector.
*
* @param weights Weights array (row vector). May be null.
*/
public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) {
super(softmaxClipEps, weights);
}
protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
INDArray oneHotLabels = toOneHot(labels, preOutput);
return super.scoreArray(oneHotLabels, preOutput, activationFn, mask);
}
@Override
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
boolean average) {
INDArray oneHotLabels = toOneHot(labels, preOutput);
return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
}
@Override
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask);
return scoreArr.sum(true,1).muli(-1);
}
@Override
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
INDArray oneHotLabels = toOneHot(labels, preOutput);
return super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
}
@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
INDArray mask, boolean average) {
INDArray oneHotLabels = toOneHot(labels, preOutput);
return new Pair<>(super.computeScore(oneHotLabels, preOutput, activationFn, mask, average),
super.computeGradient(oneHotLabels, preOutput, activationFn, mask));
}
private INDArray toOneHot(INDArray labels, INDArray preOutput){
Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
"with last dimension having size 1. Got labels array with shape %ndShape", labels);
INDArray oneHotLabels = preOutput.ulike();
Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
return oneHotLabels;
}
@Override
public String toString() {
if (weights == null)
return "LossSparseMCXENT()";
return "LossSparseMCXENT(weights=" + weights + ")";
}
}
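
To make the label-format difference concrete, a small sketch (not part of the commit; shapes follow the Javadoc above, and the Nd4j factory methods used are assumed available) comparing the two losses on the same data:

import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

public class SparseVsOneHot {
    public static void main(String[] args) {
        // Three examples, four classes. Sparse labels: one class index per example, shape [3, 1]
        INDArray sparseLabels = Nd4j.createFromArray(new int[][]{{2}, {0}, {3}});
        // The equivalent one-hot labels that LossMCXENT expects, shape [3, 4]
        INDArray oneHotLabels = Nd4j.create(new double[][]{
                {0, 0, 1, 0},
                {1, 0, 0, 0},
                {0, 0, 0, 1}});

        INDArray preOutput = Nd4j.rand(DataType.DOUBLE, 3, 4);
        double dense = new LossMCXENT().computeScore(oneHotLabels, preOutput.dup(), new ActivationSoftmax(), null, true);
        double sparse = new LossSparseMCXENT().computeScore(sparseLabels, preOutput.dup(), new ActivationSoftmax(), null, true);
        // LossSparseMCXENT one-hots the integer indices internally (via the OneHot op)
        // and delegates to LossMCXENT, so the two scores should be identical
        System.out.println(dense + " vs " + sparse);
    }
}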


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -40,6 +41,7 @@ import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.imports.listeners.ExecPrintListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -137,7 +139,7 @@ public class TFGraphTestAllHelper {
protected static void checkOnlyOutput(Map<String, INDArray> inputs, Map<String, INDArray> predictions, String modelName,
String baseDir, String modelFilename, ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
Double maxRelErrorOverride, Double minAbsErrorOverride) throws IOException {
Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException {
Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
" must be null or both must be provided");
Nd4j.EPS_THRESHOLD = 1e-3;
@@ -152,7 +154,7 @@ public class TFGraphTestAllHelper {
outputsToCheck.add(s);
}
Pair<SameDiff,Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck);
Pair<SameDiff,Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck, printArraysDebugging);
SameDiff graph = p.getFirst();
Map<String,INDArray> sameDiffPredictions = p.getSecond();
@@ -296,18 +298,18 @@ public class TFGraphTestAllHelper {
}
public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
ExecuteWith execType, File localTestDir) throws IOException {
checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir);
ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException {
checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging);
}
public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir) throws IOException {
Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException {
Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
" must be null or both must be provided");
Nd4j.EPS_THRESHOLD = 1e-3;
OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order
Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null);
Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging);
SameDiff graph = p.getFirst();
Map<String,INDArray> sdPredictions = p.getSecond();
@@ -388,13 +390,17 @@ public class TFGraphTestAllHelper {
public static Pair<SameDiff, Map<String,INDArray>> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map<String, INDArray> inputs,
ExecuteWith executeWith, BiFunction<File,String,SameDiff> graphLoaderFunction, List<Listener> listeners,
Set<String> requiredOutputs) throws IOException {
Set<String> requiredOutputs, boolean printArraysDebugging) throws IOException {
log.info("\n\tRUNNING TEST " + modelName + "...");
SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName);
if(listeners != null){
graph.setListeners(listeners);
}
if(printArraysDebugging){
graph.addListeners(new ExecPrintListener());
}
if(requiredOutputs == null){
requiredOutputs = graph.variableMap().keySet();
}


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -182,7 +183,7 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
TFGraphTestAllHelper.LOADER, maxRE, minAbs);
TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH);
}


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -194,7 +195,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
try {
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
} catch (Throwable t){
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -20,11 +21,9 @@ import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.NativeOpsHolder;
@@ -51,9 +50,12 @@ public class TFGraphTestList {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
//Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step
//Implemented internally using ExecPrintListener
public static final boolean printArraysDebugging = false;
public static String[] modelNames = new String[]{
// "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
"accumulate_n/rank0"
"resize_nearest_neighbor/int32"
};
@After
@@ -102,7 +104,7 @@ public class TFGraphTestList {
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith,
TFGraphTestAllHelper.LOADER, maxRE, minAbs);
TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging);
}
@Test @Ignore
@@ -110,7 +112,6 @@ public class TFGraphTestList {
//Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
File dir = testDir.newFolder();
Map<String, INDArray> inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir);
TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir);
TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging);
}
}


@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -274,7 +275,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
currentTestDir = testDir.newFolder();
log.info("----- SameDiff Exec: {} -----", modelName);
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF,
LOADER, maxRE, minAbs);
LOADER, maxRE, minAbs, false);
if(ArrayUtils.contains(IGNORE_REGEXES_LIBND4J_EXEC_ONLY, modelName)){
log.warn("\n\tIGNORING MODEL FOR LIBND4J EXECUTION ONLY: ");


@@ -0,0 +1,29 @@
package org.nd4j.imports.listeners;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* A very quick and dirty debugging listener
* This listener just prints the outputs of any ops during execution
* @author Alex Black
*/
public class ExecPrintListener extends BaseListener {
@Override
public boolean isActive(Operation operation) {
return true;
}
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
for(INDArray arr : outputs){
System.out.println(arr);
}
}
}
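
A usage sketch for the listener (the addListeners call is the same one used in TFGraphTestAllHelper above; the wrapper class here is hypothetical):

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.listeners.ExecPrintListener;

public class DebugListenerUsage {
    public static void attach(SameDiff sd) {
        // Prints every op's output arrays as the graph executes; debugging only,
        // since the output is extremely verbose
        sd.addListeners(new ExecPrintListener());
    }
}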


@@ -50,7 +50,7 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest {
@Test
public void testL2() {
LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.EXPLL, LossFunction.XENT,
LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT,
LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY,
LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE,
LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR,