DL4J: Add Sparse multi-class cross entropy loss function (#72)
* #8432 Add sparse mcxent loss
* Fixes for LossSparseMCXENT
* add simple debugging listener for SameDiff exec debugging
* Extra gradient check + header polishing

Signed-off-by: AlexDBlack <blacka101@gmail.com>

Branch: master
parent 823bd0ff88
commit 4a2fedf3e7
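For readers skimming the diff: LossSparseMCXENT computes the same loss as LossMCXENT, but takes labels as integer class indices (shape [minibatch, 1], or [minibatch, 1, timeSeriesLength] for sequences) instead of one-hot arrays. A minimal usage sketch, assuming the standard DL4J builder API; the layer sizes and class count below are illustrative, not taken from this commit:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

// Hypothetical 10-class classifier: labels are a [minibatch, 1] array of
// class indices 0..9, not a [minibatch, 10] one-hot array.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(new OutputLayer.Builder()
                .lossFunction(new LossSparseMCXENT())  // integer-index labels
                .activation(Activation.SOFTMAX)        // softmax output, as in the gradient checks below
                .nIn(128).nOut(10)
                .build())
        .build();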
@@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                 LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                 new LossMultiLabel(), new LossWasserstein(),
+                new LossSparseMCXENT()
         };

         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent

@@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 Activation.IDENTITY, // MixtureDensity
                 Activation.TANH, // MixtureDensity + tanh
                 Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
-                Activation.IDENTITY // Wasserstein
+                Activation.IDENTITY, // Wasserstein
+                Activation.SOFTMAX, //sparse MCXENT
         };

         int[] nOut = new int[] {1, //xent

@@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 10, // Mixture Density + tanh
                 10, // MultiLabel
                 2, // Wasserstein
+                4, //sparse MCXENT
         };

         int[] minibatchSizes = new int[] {1, 3};

@@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
                 new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                 LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
-                new LossMultiLabel(), new LossWasserstein()
+                new LossMultiLabel(), new LossWasserstein(),
+                new LossSparseMCXENT()
         };

         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent

@@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 Activation.IDENTITY, // MixtureDensity
                 Activation.TANH, // MixtureDensity + tanh
                 Activation.TANH, // MultiLabel
-                Activation.IDENTITY // Wasserstein
+                Activation.IDENTITY, // Wasserstein
+                Activation.SOFTMAX
         };

         int[] nOut = new int[] {1, //xent

@@ -300,6 +305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 10, // Mixture Density + tanh
                 10, // MultiLabel
                 2, // Wasserstein
+                4
         };

         int[] minibatchSizes = new int[] {1, 3};

@@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                     throw new UnsupportedOperationException();
                 }
                 break;
+            case "LossSparseMCXENT":
+                if (labelsShape.length == 2) {
+                    ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1);
+                    for (int i = 0; i < labelsShape[0]; i++) {
+                        ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1]));
+                    }
+                } else if (labelsShape.length == 3) {
+                    ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]);
+                    for (int i = 0; i < labelsShape[0]; i++) {
+                        for (int j = 0; j < labelsShape[2]; j++) {
+                            ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1]));
+                        }
+                    }
+                } else {
+                    throw new UnsupportedOperationException();
+                }
+                break;
             case "LossHinge":
             case "LossSquaredHinge":
@@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
 import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

 import java.util.Random;

@@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
         int nOut = 2;
         int miniBatchSize = 3;

-        ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()};
+        ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()};

         for (int maskType = 0; maskType < 3; maskType++) {

             Random r = new Random(12345L);
             INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
-            INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
-            for (int i = 0; i < miniBatchSize; i++) {
-                for (int j = 0; j < timeSeriesLength; j++) {
-                    int idx = r.nextInt(nOut);
-                    labels.putScalar(new int[]{i, idx, j}, 1.0);
-                }
-            }

             INDArray labelMask;
             String mt;

@@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
                     break;
                 case 1:
                     //Per time step masking
-                    labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength);
+                    labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength);
                     Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
                     mt = "PerTimeStep";
                     break;
                 case 2:
                     //Per output masking:
-                    labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength});
+                    labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength);
                     Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
                     mt = "PerOutput";
                     break;

@@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {

             for (ILossFunction lf : lfs) {

+                INDArray labels;
+                if(lf instanceof LossSparseMCXENT){
+                    labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength);
+                    for (int i = 0; i < miniBatchSize; i++) {
+                        for (int j = 0; j < timeSeriesLength; j++) {
+                            int idx = r.nextInt(nOut);
+                            labels.putScalar(new int[]{i, 0, j}, idx);
+                        }
+                    }
+                } else {
+                    labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
+                    for (int i = 0; i < miniBatchSize; i++) {
+                        for (int j = 0; j < timeSeriesLength; j++) {
+                            int idx = r.nextInt(nOut);
+                            labels.putScalar(new int[]{i, idx, j}, 1.0);
+                        }
+                    }
+                }

                 Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX;

                 MultiLayerConfiguration conf =
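To make the two label encodings in the hunk above concrete, here is a standalone sketch (values are illustrative) of equivalent labels for a 2-class time series of length 2:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Dense one-hot labels for LossMCXENT: shape [miniBatchSize, nOut, timeSeriesLength].
// One example, class 1 at time step 0 and class 0 at time step 1:
INDArray oneHot = Nd4j.zeros(1, 2, 2);
oneHot.putScalar(new int[]{0, 1, 0}, 1.0);
oneHot.putScalar(new int[]{0, 0, 1}, 1.0);

// The same labels for LossSparseMCXENT: shape [miniBatchSize, 1, timeSeriesLength],
// storing each class index directly:
INDArray sparse = Nd4j.zeros(1, 1, 2);
sparse.putScalar(new int[]{0, 0, 0}, 1);  // class 1 at t=0
sparse.putScalar(new int[]{0, 0, 1}, 0);  // class 0 at t=1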
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -26,16 +27,27 @@ import org.nd4j.linalg.lossfunctions.impl.*;
 public class LossFunctions {

     /**
-     * MSE: Mean Squared Error: Linear Regression<br>
-     * EXPLL: Exponential log likelihood: Poisson Regression<br>
-     * XENT: Cross Entropy: Binary Classification<br>
-     * MCXENT: Multiclass Cross Entropy<br>
-     * RMSE_XENT: RMSE Cross Entropy<br>
-     * SQUARED_LOSS: Squared Loss<br>
-     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood<br>
+     * MSE: Mean Squared Error: Linear Regression - {@link LossMSE}<br>
+     * l1: L1 loss (absolute value) - {@link LossL1}<br>
+     * XENT: Cross Entropy: Binary Classification - {@link LossBinaryXENT}<br>
+     * MCXENT: Multiclass Cross Entropy - {@link LossMCXENT}<br>
+     * SPARSE_MCXENT: Sparse multi-class cross entropy - {@link LossSparseMCXENT}<br>
+     * SQUARED_LOSS: Alias for mean squared error - {@link LossMSE}<br>
+     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood - {@link LossNegativeLogLikelihood}<br>
+     * COSINE_PROXIMITY: Cosine proximity loss - {@link LossCosineProximity}<br>
+     * HINGE: Hinge loss - {@link LossHinge}<br>
+     * SQUARED_HINGE: Squared hinge loss - {@link LossSquaredHinge}<br>
+     * KL_DIVERGENCE: Kullback-Leibler divergence loss - {@link LossKLD}<br>
+     * MEAN_ABSOLUTE_ERROR: mean absolute error loss - {@link LossMAE}<br>
+     * L2: L2 loss (sum of squared errors) - {@link LossL2}<br>
+     * MEAN_ABSOLUTE_PERCENTAGE_ERROR: MAPE loss - {@link LossMAPE}<br>
+     * MEAN_SQUARED_LOGARITHMIC_ERROR: MSLE loss - {@link LossMSLE}<br>
+     * POISSON: Poisson loss - {@link LossPoisson}<br>
+     * WASSERSTEIN: Wasserstein loss - {@link LossWasserstein}
      */
     public enum LossFunction {
-        MSE, L1, @Deprecated EXPLL, XENT, MCXENT, @Deprecated RMSE_XENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, @Deprecated CUSTOM, COSINE_PROXIMITY, HINGE, SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
+        MSE, L1, XENT, MCXENT, SPARSE_MCXENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, COSINE_PROXIMITY, HINGE,
+        SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;

     public ILossFunction getILossFunction() {
         switch (this) {

@@ -48,6 +60,8 @@ public class LossFunctions {
                 return new LossBinaryXENT();
             case MCXENT:
                 return new LossMCXENT();
+            case SPARSE_MCXENT:
+                return new LossSparseMCXENT();
             case KL_DIVERGENCE:
             case RECONSTRUCTION_CROSSENTROPY:
                 return new LossKLD();

@@ -68,7 +82,6 @@ public class LossFunctions {
             case MEAN_SQUARED_LOGARITHMIC_ERROR:
                 return new LossMSLE();
             case POISSON:
-            case EXPLL:
                 return new LossPoisson();
             case WASSERSTEIN:
                 return new LossWasserstein();
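A one-line usage sketch showing how the new enum constant resolves through the factory method added above; this follows directly from the new switch case:

import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

// SPARSE_MCXENT maps to a default-configured LossSparseMCXENT instance:
ILossFunction lf = LossFunction.SPARSE_MCXENT.getILossFunction();
assert lf instanceof LossSparseMCXENT;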
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -40,10 +41,13 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
 /**
  *
  * Multi-Class Cross Entropy loss function:<br>
- * L = sum_i actual_i * log( predicted_i )
+ * L = sum_i actual_i * log( predicted_i )<br>
+ * Note that labels are represented by a one-hot distribution<br>
+ * See {@link LossSparseMCXENT} for the equivalent but with labels as integers instead
  *
  * @author Alex Black, Susan Eraly
  * @see LossNegativeLogLikelihood
+ * @see LossSparseMCXENT
  */
 @EqualsAndHashCode
 @JsonInclude(JsonInclude.Include.NON_NULL)

@@ -53,9 +57,9 @@ public class LossMCXENT implements ILossFunction {

     @JsonSerialize(using = NDArrayTextSerializer.class)
     @JsonDeserialize(using = NDArrayTextDeSerializer.class)
-    private INDArray weights;
+    protected INDArray weights;

-    private double softmaxClipEps;
+    protected double softmaxClipEps;

     public LossMCXENT() {
         this(null);

@@ -91,7 +95,7 @@ public class LossMCXENT implements ILossFunction {
         this.softmaxClipEps = softmaxClipEps;
     }

-    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+    protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
         if(!labels.equalShapes(preOutput)){
             Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
         }
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -0,0 +1,133 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.lossfunctions.impl;
+
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.shape.OneHot;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.BooleanIndexing;
+import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.lossfunctions.LossUtil;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
+import org.nd4j.shade.jackson.annotation.JsonInclude;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
+import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
+
+/**
+ *
+ * Sparse Multi-Class Cross Entropy loss function:<br>
+ * L = sum_i actual_i * log( predicted_i )<br>
+ * Note: this is the same loss function as {@link LossMCXENT}, the only difference being the format for the labels -
+ * this loss function uses integer indices (zero indexed) for the loss array, whereas LossMCXENT uses the equivalent
+ * one-hot representation
+ *
+ * @author Alex Black
+ * @see LossNegativeLogLikelihood
+ * @see LossMCXENT
+ */
+@EqualsAndHashCode(callSuper = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+@Getter @Setter
+public class LossSparseMCXENT extends LossMCXENT {
+    private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10;
+
+    public LossSparseMCXENT() {
+        this(null);
+    }
+
+    /**
+     * Multi-Class Cross Entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
+     * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
+     * A weight vector of 1s should give identical results to no weight vector.
+     *
+     * @param weights Weights array (row vector). May be null.
+     */
+    public LossSparseMCXENT(INDArray weights) {
+        this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
+    }
+
+    /**
+     * Multi-Class Cross Entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
+     * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
+     * A weight vector of 1s should give identical results to no weight vector.
+     *
+     * @param weights Weights array (row vector). May be null.
+     */
+    public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) {
+        super(softmaxClipEps, weights);
+    }
+
+    protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.scoreArray(oneHotLabels, preOutput, activationFn, mask);
+    }
+
+    @Override
+    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
+                               boolean average) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
+    }
+
+    @Override
+    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask);
+        return scoreArr.sum(true,1).muli(-1);
+    }
+
+    @Override
+    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
+    }
+
+    @Override
+    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
+                                                          INDArray mask, boolean average) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return new Pair<>(super.computeScore(oneHotLabels, preOutput, activationFn, mask, average),
+                super.computeGradient(oneHotLabels, preOutput, activationFn, mask));
+    }
+
+    private INDArray toOneHot(INDArray labels, INDArray preOutput){
+        Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
+                "with last dimension having size 1. Got labels array with shape %ndShape", labels);
+        INDArray oneHotLabels = preOutput.ulike();
+        Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
+        return oneHotLabels;
+    }
+
+    @Override
+    public String toString() {
+        if (weights == null)
+            return "LossSparseMCXENT()";
+        return "LossSparseMCXENT(weights=" + weights + ")";
+    }
+}
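Because the class converts integer labels to one-hot via toOneHot() and delegates everything else to LossMCXENT, the two classes should score identically on equivalent labels. A sanity-check sketch with made-up values (not part of the commit):

import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

INDArray preOut = Nd4j.rand(3, 4);                  // [minibatch=3, nClasses=4] pre-activations

INDArray sparse = Nd4j.create(DataType.INT, 3, 1);  // integer class indices, last dimension size 1
sparse.putScalar(0, 0, 2);
sparse.putScalar(1, 0, 0);
sparse.putScalar(2, 0, 3);

INDArray oneHot = Nd4j.zeros(3, 4);                 // the equivalent one-hot labels
for (int i = 0; i < 3; i++) {
    oneHot.putScalar(new int[]{i, sparse.getInt(i, 0)}, 1.0);
}

double dense = new LossMCXENT().computeScore(oneHot, preOut, new ActivationSoftmax(), null, true);
double sp = new LossSparseMCXENT().computeScore(sparse, preOut, new ActivationSoftmax(), null, true);
// Expect dense == sp, up to floating-point noise.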
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -40,6 +41,7 @@ import org.nd4j.autodiff.validation.TestCase;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.imports.listeners.ExecPrintListener;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -137,7 +139,7 @@ public class TFGraphTestAllHelper {

     protected static void checkOnlyOutput(Map<String, INDArray> inputs, Map<String, INDArray> predictions, String modelName,
                                           String baseDir, String modelFilename, ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
-                                          Double maxRelErrorOverride, Double minAbsErrorOverride) throws IOException {
+                                          Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException {
         Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
                 " must be null or both must be provided");
         Nd4j.EPS_THRESHOLD = 1e-3;

@@ -152,7 +154,7 @@ public class TFGraphTestAllHelper {
             outputsToCheck.add(s);
         }

-        Pair<SameDiff,Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck);
+        Pair<SameDiff,Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck, printArraysDebugging);
         SameDiff graph = p.getFirst();
         Map<String,INDArray> sameDiffPredictions = p.getSecond();

@@ -296,18 +298,18 @@ public class TFGraphTestAllHelper {
     }

     public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
-                                         ExecuteWith execType, File localTestDir) throws IOException {
-        checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir);
+                                         ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException {
+        checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging);
     }

     public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
                                          ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
-                                         Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir) throws IOException {
+                                         Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException {
         Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
                 " must be null or both must be provided");
         Nd4j.EPS_THRESHOLD = 1e-3;
         OpExecOrderListener listener = new OpExecOrderListener();       //Used to collect exec order
-        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null);
+        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging);
         SameDiff graph = p.getFirst();
         Map<String,INDArray> sdPredictions = p.getSecond();

@@ -388,13 +390,17 @@ public class TFGraphTestAllHelper {

     public static Pair<SameDiff, Map<String,INDArray>> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map<String, INDArray> inputs,
                                                                          ExecuteWith executeWith, BiFunction<File,String,SameDiff> graphLoaderFunction, List<Listener> listeners,
-                                                                         Set<String> requiredOutputs) throws IOException {
+                                                                         Set<String> requiredOutputs, boolean printArraysDebugging) throws IOException {
         log.info("\n\tRUNNING TEST " + modelName + "...");
         SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName);
         if(listeners != null){
             graph.setListeners(listeners);
         }

+        if(printArraysDebugging){
+            graph.addListeners(new ExecPrintListener());
+        }
+
         if(requiredOutputs == null){
             requiredOutputs = graph.variableMap().keySet();
         }
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -182,7 +183,7 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
         //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH);
     }
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -194,7 +195,7 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

         try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
             //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
         } catch (Throwable t){
             log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -20,11 +21,9 @@ import org.junit.*;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
-import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.nativeblas.NativeOpsHolder;

@@ -51,9 +50,12 @@ public class TFGraphTestList {
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();

+    //Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step
+    //Implemented internally using ExecPrintListener
+    public static final boolean printArraysDebugging = false;
+
     public static String[] modelNames = new String[]{
//            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
-            "accumulate_n/rank0"
+            "resize_nearest_neighbor/int32"
     };

     @After

@@ -102,7 +104,7 @@ public class TFGraphTestList {
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());

         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging);
     }

     @Test @Ignore

@@ -110,7 +112,6 @@ public class TFGraphTestList {
         //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
         File dir = testDir.newFolder();
         Map<String, INDArray> inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir);
-        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir);
-
+        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging);
     }
 }
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -274,7 +275,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
         currentTestDir = testDir.newFolder();
         log.info("----- SameDiff Exec: {} -----", modelName);
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF,
-                LOADER, maxRE, minAbs);
+                LOADER, maxRE, minAbs, false);

         if(ArrayUtils.contains(IGNORE_REGEXES_LIBND4J_EXEC_ONLY, modelName)){
             log.warn("\n\tIGNORING MODEL FOR LIBND4J EXECUTION ONLY: ");
@@ -0,0 +1,29 @@
+package org.nd4j.imports.listeners;
+
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+
+/**
+ * A very quick and dirty debugging listener
+ * This listener just prints the outputs of any ops during execution
+ * @author Alex Black
+ */
+public class ExecPrintListener extends BaseListener {
+    @Override
+    public boolean isActive(Operation operation) {
+        return true;
+    }
+
+    @Override
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+        System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
+        for(INDArray arr : outputs){
+            System.out.println(arr);
+        }
+    }
+}
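A sketch of attaching the listener when debugging a SameDiff graph directly, outside the test helpers. The graph here is illustrative, and method names such as output(...) and name() assume the SameDiff API of this period; treat it as a sketch rather than a verbatim recipe:

import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.listeners.ExecPrintListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable out = in.add(1.0).mul(2.0);     // stand-in computation

sd.addListeners(new ExecPrintListener());  // print every op's output arrays, as getGraphAfterExec does above

Map<String, INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));
sd.output(ph, out.name());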
@@ -50,7 +50,7 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest {

     @Test
     public void testL2() {
-        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.EXPLL, LossFunction.XENT,
+        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT,
                 LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY,
                 LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE,
                 LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR,