DL4J: Add Sparse multi-class cross entropy loss function (#72)
* #8432 Add sparse mcxent loss
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes for LossSparseMCXENT
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* add simple debugging listener for SameDiff exec debugging
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Extra gradient check + header polishing
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
parent 823bd0ff88
commit 4a2fedf3e7
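For context, a minimal usage sketch (an editor's illustration, not part of this commit; the builder API is standard DL4J, but the class name here is invented for the example): the new LossSparseMCXENT is a drop-in replacement for LossMCXENT wherever integer class indices are more convenient than one-hot labels.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

public class SparseMcxentConfigExample {
    public static void main(String[] args) {
        // Same configuration one would use with LossMCXENT, but labels may now be
        // integer class indices of shape [minibatch, 1] rather than one-hot [minibatch, nOut]
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(10).activation(Activation.TANH).build())
                .layer(new OutputLayer.Builder(new LossSparseMCXENT())
                        .nIn(10).nOut(3)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();
        System.out.println(conf.toJson());
    }
}
```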
@@ -83,6 +83,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         new LossMultiLabel(), new LossWasserstein(),
+                        new LossSparseMCXENT()
         };
 
         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -116,7 +117,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         Activation.IDENTITY, // MixtureDensity
                         Activation.TANH, // MixtureDensity + tanh
                         Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
-                        Activation.IDENTITY // Wasserstein
+                        Activation.IDENTITY, // Wasserstein
+                        Activation.SOFTMAX, //sparse MCXENT
         };
 
         int[] nOut = new int[] {1, //xent
@@ -151,6 +153,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         10, // Mixture Density + tanh
                         10, // MultiLabel
                         2, // Wasserstein
+                        4, //sparse MCXENT
         };
 
         int[] minibatchSizes = new int[] {1, 3};
@@ -233,7 +236,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         new LossSquaredHinge(), new LossFMeasure(), new LossFMeasure(2.0), new LossFMeasure(),
                         new LossFMeasure(2.0), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                         LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
-                        new LossMultiLabel(), new LossWasserstein()
+                        new LossMultiLabel(), new LossWasserstein(),
+                        new LossSparseMCXENT()
         };
 
         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -266,7 +270,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         Activation.IDENTITY, // MixtureDensity
                         Activation.TANH, // MixtureDensity + tanh
                         Activation.TANH, // MultiLabel
-                        Activation.IDENTITY // Wasserstein
+                        Activation.IDENTITY, // Wasserstein
+                        Activation.SOFTMAX
         };
 
         int[] nOut = new int[] {1, //xent
@@ -300,6 +305,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         10, // Mixture Density + tanh
                         10, // MultiLabel
                         2, // Wasserstein
+                        4
         };
 
         int[] minibatchSizes = new int[] {1, 3};
@@ -476,6 +482,23 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                         throw new UnsupportedOperationException();
                     }
 
+                    break;
+                case "LossSparseMCXENT":
+                    if (labelsShape.length == 2) {
+                        ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1);
+                        for (int i = 0; i < labelsShape[0]; i++) {
+                            ret[1].putScalar(i, 0, r.nextInt((int) labelsShape[1]));
+                        }
+                    } else if (labelsShape.length == 3) {
+                        ret[1] = Nd4j.create(DataType.INT, labelsShape[0], 1, labelsShape[2]);
+                        for (int i = 0; i < labelsShape[0]; i++) {
+                            for (int j = 0; j < labelsShape[2]; j++) {
+                                ret[1].putScalar(i, 0, j, r.nextInt((int) labelsShape[1]));
+                            }
+                        }
+                    } else {
+                        throw new UnsupportedOperationException();
+                    }
                     break;
                 case "LossHinge":
                 case "LossSquaredHinge":
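To make the two label formats concrete, a small sketch (editor's illustration; the class name is hypothetical) of the 2D case the gradient check above generates, for nOut = 4 classes:

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SparseLabelFormatExample {
    public static void main(String[] args) {
        // One-hot labels for LossMCXENT: shape [minibatch, nOut]
        INDArray oneHot = Nd4j.create(new double[][]{
                {0, 0, 1, 0},    // example 0 is class 2
                {1, 0, 0, 0}});  // example 1 is class 0

        // Integer labels for LossSparseMCXENT: shape [minibatch, 1]
        INDArray sparse = Nd4j.create(DataType.INT, 2, 1);
        sparse.putScalar(0, 0, 2);
        sparse.putScalar(1, 0, 0);

        System.out.println(oneHot + "\n" + sparse);
    }
}
```

For time series (the 3D branch above), the sparse labels keep a size-1 class dimension: shape [minibatch, 1, timeSeriesLength] of integer indices, versus one-hot [minibatch, nOut, timeSeriesLength].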
@@ -34,6 +34,7 @@ import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
 import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
 
 import java.util.Random;
 
@@ -61,19 +62,12 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
         int nOut = 2;
         int miniBatchSize = 3;
 
-        ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT()};
+        ILossFunction[] lfs = new ILossFunction[]{new LossMSE(), new LossMCXENT(), new LossSparseMCXENT()};
 
         for (int maskType = 0; maskType < 3; maskType++) {
 
             Random r = new Random(12345L);
             INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength});
-            INDArray labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
-            for (int i = 0; i < miniBatchSize; i++) {
-                for (int j = 0; j < timeSeriesLength; j++) {
-                    int idx = r.nextInt(nOut);
-                    labels.putScalar(new int[]{i, idx, j}, 1.0);
-                }
-            }
 
             INDArray labelMask;
             String mt;
@@ -85,13 +79,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
                     break;
                 case 1:
                     //Per time step masking
-                    labelMask = Nd4j.createUninitialized(miniBatchSize, timeSeriesLength);
+                    labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, timeSeriesLength);
                     Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
                     mt = "PerTimeStep";
                     break;
                 case 2:
                     //Per output masking:
-                    labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, nOut, timeSeriesLength});
+                    labelMask = Nd4j.createUninitialized(DataType.DOUBLE, miniBatchSize, nOut, timeSeriesLength);
                     Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5));
                     mt = "PerOutput";
                     break;
@@ -101,6 +95,26 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
 
             for (ILossFunction lf : lfs) {
 
+                INDArray labels;
+                if(lf instanceof LossSparseMCXENT){
+                    labels = Nd4j.zeros(miniBatchSize, 1, timeSeriesLength);
+                    for (int i = 0; i < miniBatchSize; i++) {
+                        for (int j = 0; j < timeSeriesLength; j++) {
+                            int idx = r.nextInt(nOut);
+                            labels.putScalar(new int[]{i, 0, j}, idx);
+                        }
+                    }
+                } else {
+                    labels = Nd4j.zeros(miniBatchSize, nOut, timeSeriesLength);
+                    for (int i = 0; i < miniBatchSize; i++) {
+                        for (int j = 0; j < timeSeriesLength; j++) {
+                            int idx = r.nextInt(nOut);
+                            labels.putScalar(new int[]{i, idx, j}, 1.0);
+                        }
+                    }
+                }
+
+
                 Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX;
 
                 MultiLayerConfiguration conf =
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -26,16 +27,27 @@ import org.nd4j.linalg.lossfunctions.impl.*;
 public class LossFunctions {
 
     /**
-     * MSE: Mean Squared Error: Linear Regression<br>
-     * EXPLL: Exponential log likelihood: Poisson Regression<br>
-     * XENT: Cross Entropy: Binary Classification<br>
-     * MCXENT: Multiclass Cross Entropy<br>
-     * RMSE_XENT: RMSE Cross Entropy<br>
-     * SQUARED_LOSS: Squared Loss<br>
-     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood<br>
+     * MSE: Mean Squared Error: Linear Regression - {@link LossMSE}<br>
+     * l1: L1 loss (absolute value) - {@link LossL1}<br>
+     * XENT: Cross Entropy: Binary Classification - {@link LossBinaryXENT}<br>
+     * MCXENT: Multiclass Cross Entropy - {@link LossMCXENT}<br>
+     * SPARSE_MCXENT: Sparse multi-class cross entropy - {@link LossSparseMCXENT}<br>
+     * SQUARED_LOSS: Alias for mean squared error - {@link LossMSE}<br>
+     * NEGATIVELOGLIKELIHOOD: Negative Log Likelihood - {@link LossNegativeLogLikelihood}<br>
+     * COSINE_PROXIMITY: Cosine proximity loss - {@link LossCosineProximity}<br>
+     * HINGE: Hinge loss - {@link LossHinge}<br>
+     * SQUARED_HINGE: Squared hinge loss - {@link LossSquaredHinge}<br>
+     * KL_DIVERGENCE: Kullback-Leibler divergence loss - {@link LossKLD}<br>
+     * MEAN_ABSOLUTE_ERROR: mean absolute error loss - {@link LossMAE}<br>
+     * L2: L2 loss (sum of squared errors) - {@link LossL2}<br>
+     * MEAN_ABSOLUTE_PERCENTAGE_ERROR: MAPE loss - {@link LossMAPE}<br>
+     * MEAN_SQUARED_LOGARITHMIC_ERROR: MSLE loss - {@link LossMSLE}<br>
+     * POISSON: Poisson loss - {@link LossPoisson}<br>
+     * WASSERSTEIN: Wasserstein loss - {@link LossWasserstein}
      */
     public enum LossFunction {
-        MSE, L1, @Deprecated EXPLL, XENT, MCXENT, @Deprecated RMSE_XENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, @Deprecated CUSTOM, COSINE_PROXIMITY, HINGE, SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
+        MSE, L1, XENT, MCXENT, SPARSE_MCXENT, SQUARED_LOSS, RECONSTRUCTION_CROSSENTROPY, NEGATIVELOGLIKELIHOOD, COSINE_PROXIMITY, HINGE,
+        SQUARED_HINGE, KL_DIVERGENCE, MEAN_ABSOLUTE_ERROR, L2, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_LOGARITHMIC_ERROR, POISSON, WASSERSTEIN;
 
         public ILossFunction getILossFunction() {
             switch (this) {
@@ -48,6 +60,8 @@ public class LossFunctions {
                     return new LossBinaryXENT();
                 case MCXENT:
                     return new LossMCXENT();
+                case SPARSE_MCXENT:
+                    return new LossSparseMCXENT();
                 case KL_DIVERGENCE:
                 case RECONSTRUCTION_CROSSENTROPY:
                     return new LossKLD();
@@ -68,7 +82,6 @@ public class LossFunctions {
                 case MEAN_SQUARED_LOGARITHMIC_ERROR:
                     return new LossMSLE();
                 case POISSON:
-                case EXPLL:
                     return new LossPoisson();
                 case WASSERSTEIN:
                     return new LossWasserstein();
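A quick sketch (editor's illustration; the class name is hypothetical) of what the new enum constant resolves to via the switch above:

```java
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class SparseMcxentEnumExample {
    public static void main(String[] args) {
        // SPARSE_MCXENT maps to LossSparseMCXENT, just as MCXENT maps to LossMCXENT
        ILossFunction lf = LossFunction.SPARSE_MCXENT.getILossFunction();
        System.out.println(lf); // LossSparseMCXENT()
    }
}
```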
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -40,10 +41,13 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
 /**
  *
  * Multi-Class Cross Entropy loss function:<br>
- * L = sum_i actual_i * log( predicted_i )
+ * L = sum_i actual_i * log( predicted_i )<br>
+ * Note that labels are represented by a one-hot distribution<br>
+ * See {@link LossSparseMCXENT} for the equivalent but with labels as integers instead
  *
  * @author Alex Black, Susan Eraly
  * @see LossNegativeLogLikelihood
+ * @see LossSparseMCXENT
  */
 @EqualsAndHashCode
 @JsonInclude(JsonInclude.Include.NON_NULL)
@@ -53,9 +57,9 @@ public class LossMCXENT implements ILossFunction {
 
     @JsonSerialize(using = NDArrayTextSerializer.class)
     @JsonDeserialize(using = NDArrayTextDeSerializer.class)
-    private INDArray weights;
+    protected INDArray weights;
 
-    private double softmaxClipEps;
+    protected double softmaxClipEps;
 
     public LossMCXENT() {
         this(null);
@@ -91,7 +95,7 @@ public class LossMCXENT implements ILossFunction {
         this.softmaxClipEps = softmaxClipEps;
     }
 
-    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+    protected INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
         if(!labels.equalShapes(preOutput)){
             Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
         }
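Restating the javadoc formula in conventional notation (an editor's note: the implementation sums the score array and negates it, so the reported loss is the usual negative log likelihood):

```latex
\[
L(y,\hat{y}) = -\sum_i y_i \log \hat{y}_i
\qquad\text{(one-hot labels } y\text{, LossMCXENT)}
\]
\[
L(c,\hat{y}) = -\log \hat{y}_c
\qquad\text{(integer class index } c\text{, LossSparseMCXENT; the same quantity, since } y \text{ is one-hot)}
\]
```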
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -0,0 +1,133 @@
+/* ******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.lossfunctions.impl;
+
+
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.shape.OneHot;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.BooleanIndexing;
+import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.lossfunctions.LossUtil;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
+import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
+import org.nd4j.shade.jackson.annotation.JsonInclude;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
+import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
+
+/**
+ *
+ * Sparse Multi-Class Cross Entropy loss function:<br>
+ * L = sum_i actual_i * log( predicted_i )<br>
+ * Note: this is the same loss function as {@link LossMCXENT}, the only difference being the format for the labels -
+ * this loss function uses integer indices (zero indexed) for the loss array, whereas LossMCXENT uses the equivalent
+ * one-hot representation
+ *
+ * @author Alex Black
+ * @see LossNegativeLogLikelihood
+ * @see LossMCXENT
+ */
+@EqualsAndHashCode(callSuper = true)
+@JsonInclude(JsonInclude.Include.NON_NULL)
+@Getter @Setter
+public class LossSparseMCXENT extends LossMCXENT {
+    private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1e-10;
+
+    public LossSparseMCXENT() {
+        this(null);
+    }
+
+    /**
+     * Sparse multi-class cross entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
+     * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
+     * A weight vector of 1s should give identical results to no weight vector.
+     *
+     * @param weights Weights array (row vector). May be null.
+     */
+    public LossSparseMCXENT(INDArray weights) {
+        this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
+    }
+
+    /**
+     * Sparse multi-class cross entropy loss function where each output is (optionally) weighted/scaled by a fixed scalar value.
+     * Note that the weights array must be a row vector, of length equal to the labels/output dimension 1 size.
+     * A weight vector of 1s should give identical results to no weight vector.
+     *
+     * @param weights Weights array (row vector). May be null.
+     */
+    public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps, @JsonProperty("weights") INDArray weights) {
+        super(softmaxClipEps, weights);
+    }
+
+    protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.scoreArray(oneHotLabels, preOutput, activationFn, mask);
+    }
+
+    @Override
+    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
+                               boolean average) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
+    }
+
+    @Override
+    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask);
+        return scoreArr.sum(true,1).muli(-1);
+    }
+
+    @Override
+    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
+    }
+
+    @Override
+    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
+                                                          INDArray mask, boolean average) {
+        INDArray oneHotLabels = toOneHot(labels, preOutput);
+        return new Pair<>(super.computeScore(oneHotLabels, preOutput, activationFn, mask, average),
+                super.computeGradient(oneHotLabels, preOutput, activationFn, mask));
+    }
+
+    private INDArray toOneHot(INDArray labels, INDArray preOutput){
+        Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
+                "with last dimension having size 1. Got labels array with shape %ndShape", labels);
+        INDArray oneHotLabels = preOutput.ulike();
+        Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
+        return oneHotLabels;
+    }
+
+
+    @Override
+    public String toString() {
+        if (weights == null)
+            return "LossSparseMCXENT()";
+        return "LossSparseMCXENT(weights=" + weights + ")";
+    }
+}
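A small sketch (editor's illustration; the class name is hypothetical) of the equivalence the new class implements: the sparse loss on integer labels should match LossMCXENT on the corresponding one-hot labels.

```java
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;

public class SparseVsDenseExample {
    public static void main(String[] args) {
        INDArray preOut = Nd4j.rand(DataType.DOUBLE, 3, 5);        // [minibatch, nOut] pre-activations

        INDArray sparseLabels = Nd4j.createFromArray(4, 0, 2)      // integer class indices
                .reshape(3, 1);                                    // shape [minibatch, 1]
        INDArray oneHot = Nd4j.zeros(DataType.DOUBLE, 3, 5);       // equivalent one-hot labels
        oneHot.putScalar(0, 4, 1.0);
        oneHot.putScalar(1, 0, 1.0);
        oneHot.putScalar(2, 2, 1.0);

        double sparse = new LossSparseMCXENT().computeScore(sparseLabels, preOut, new ActivationSoftmax(), null, true);
        double dense = new LossMCXENT().computeScore(oneHot, preOut, new ActivationSoftmax(), null, true);
        System.out.println(sparse + " should equal " + dense);     // agree up to numerical precision
    }
}
```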
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -40,6 +41,7 @@ import org.nd4j.autodiff.validation.TestCase;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.imports.listeners.ExecPrintListener;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -137,7 +139,7 @@ public class TFGraphTestAllHelper {
 
     protected static void checkOnlyOutput(Map<String, INDArray> inputs, Map<String, INDArray> predictions, String modelName,
                                           String baseDir, String modelFilename, ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
-                                          Double maxRelErrorOverride, Double minAbsErrorOverride) throws IOException {
+                                          Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException {
         Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
                 " must be null or both must be provided");
         Nd4j.EPS_THRESHOLD = 1e-3;
@@ -152,7 +154,7 @@ public class TFGraphTestAllHelper {
             outputsToCheck.add(s);
         }
 
-        Pair<SameDiff,Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck);
+        Pair<SameDiff,Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFilename, modelName, inputs, execType, loader, null, outputsToCheck, printArraysDebugging);
         SameDiff graph = p.getFirst();
         Map<String,INDArray> sameDiffPredictions = p.getSecond();
 
@@ -296,18 +298,18 @@ public class TFGraphTestAllHelper {
     }
 
     public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
-                                         ExecuteWith execType, File localTestDir) throws IOException {
-        checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir);
+                                         ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException {
+        checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging);
     }
 
     public static void checkIntermediate(Map<String, INDArray> inputs, String modelName, String baseDir, String modelFileName,
                                          ExecuteWith execType, BiFunction<File,String,SameDiff> loader,
-                                         Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir) throws IOException {
+                                         Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException {
         Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" +
                 " must be null or both must be provided");
         Nd4j.EPS_THRESHOLD = 1e-3;
         OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order
-        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null);
+        Pair<SameDiff, Map<String,INDArray>> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging);
         SameDiff graph = p.getFirst();
         Map<String,INDArray> sdPredictions = p.getSecond();
 
@@ -388,13 +390,17 @@ public class TFGraphTestAllHelper {
 
     public static Pair<SameDiff, Map<String,INDArray>> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map<String, INDArray> inputs,
                                                                          ExecuteWith executeWith, BiFunction<File,String,SameDiff> graphLoaderFunction, List<Listener> listeners,
-                                                                         Set<String> requiredOutputs) throws IOException {
+                                                                         Set<String> requiredOutputs, boolean printArraysDebugging) throws IOException {
         log.info("\n\tRUNNING TEST " + modelName + "...");
         SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName);
         if(listeners != null){
             graph.setListeners(listeners);
         }
 
+        if(printArraysDebugging){
+            graph.addListeners(new ExecPrintListener());
+        }
+
         if(requiredOutputs == null){
             requiredOutputs = graph.variableMap().keySet();
         }
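For example (editor's sketch, mirroring the call sites updated in the test classes below), a test opts into array printing by passing true for the new parameter:

```java
// Inside one of the TF graph tests: true attaches an ExecPrintListener to the
// loaded graph, so every op's output arrays are printed during execution
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME,
        EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, true);
```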
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -182,7 +183,7 @@ public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
         //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH);
     }
 
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -194,7 +195,7 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
         try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
             //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
         } catch (Throwable t){
             log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -20,11 +21,9 @@ import org.junit.*;
 import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
-import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.nativeblas.NativeOpsHolder;
 
@@ -51,9 +50,12 @@ public class TFGraphTestList {
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();
 
+    //Only enable this for debugging, and leave it disabled for normal testing and CI - it prints all arrays for every execution step
+    //Implemented internally using ExecPrintListener
+    public static final boolean printArraysDebugging = false;
+
     public static String[] modelNames = new String[]{
-            // "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
-            "accumulate_n/rank0"
+            "resize_nearest_neighbor/int32"
     };
 
     @After
@@ -102,7 +104,7 @@ public class TFGraphTestList {
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith,
-                TFGraphTestAllHelper.LOADER, maxRE, minAbs);
+                TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging);
     }
 
     @Test @Ignore
@@ -110,7 +112,6 @@ public class TFGraphTestList {
         //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
         File dir = testDir.newFolder();
         Map<String, INDArray> inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir);
-        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir);
-
+        TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging);
     }
 }
@@ -1,5 +1,6 @@
 /* ******************************************************************************
  * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -274,7 +275,7 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we
         currentTestDir = testDir.newFolder();
         log.info("----- SameDiff Exec: {} -----", modelName);
         TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF,
-                LOADER, maxRE, minAbs);
+                LOADER, maxRE, minAbs, false);
 
         if(ArrayUtils.contains(IGNORE_REGEXES_LIBND4J_EXEC_ONLY, modelName)){
             log.warn("\n\tIGNORING MODEL FOR LIBND4J EXECUTION ONLY: ");
@@ -0,0 +1,29 @@
+package org.nd4j.imports.listeners;
+
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Operation;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.SameDiffOp;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.MultiDataSet;
+
+/**
+ * A very quick and dirty debugging listener
+ * This listener just prints the outputs of any ops during execution
+ * @author Alex Black
+ */
+public class ExecPrintListener extends BaseListener {
+    @Override
+    public boolean isActive(Operation operation) {
+        return true;
+    }
+
+    @Override
+    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
+        System.out.println("------ Op: " + op.getName() + " - opName = " + op.getOp().opName() + ", class = " + op.getOp().getClass().getName() + " ------");
+        for(INDArray arr : outputs){
+            System.out.println(arr);
+        }
+    }
+}
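A minimal sketch (editor's illustration, assuming the beta-era SameDiff API; the graph built here is invented for the example) of using the listener directly:

```java
import java.util.Collections;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.listeners.ExecPrintListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class ExecPrintExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3);
        SDVariable out = sd.math().tanh("out", in);

        sd.addListeners(new ExecPrintListener()); // prints each op's outputs as it executes
        sd.output(Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 1, 3)), "out");
    }
}
```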
@@ -50,7 +50,7 @@ public class TestLossFunctionsSizeChecks extends BaseNd4jTest {
 
     @Test
     public void testL2() {
-        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.EXPLL, LossFunction.XENT,
+        LossFunction[] lossFunctionList = {LossFunction.MSE, LossFunction.L1, LossFunction.XENT,
                         LossFunction.MCXENT, LossFunction.SQUARED_LOSS, LossFunction.RECONSTRUCTION_CROSSENTROPY,
                         LossFunction.NEGATIVELOGLIKELIHOOD, LossFunction.COSINE_PROXIMITY, LossFunction.HINGE,
                         LossFunction.SQUARED_HINGE, LossFunction.KL_DIVERGENCE, LossFunction.MEAN_ABSOLUTE_ERROR,