Various Fixes (#75)
* #8431 Cast loss function weights array automatically Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add 'regex verbose mode' printing (ExecDebugListener) for TFGraphTestAllSameDiff' Signed-off-by: AlexDBlack <blacka101@gmail.com> * Class import mapping fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Reshape fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Don't swallow first exception in NativeOpExecutioner.exec(CustomOp) Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
8d87b078c2
commit
e910ce75ec
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
|
||||
import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
|
@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest {
|
|||
|
||||
assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testPreprocessorVertex(){
|
||||
for(boolean withMinibatchDim : new boolean[]{true, false}){
|
||||
long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32};
|
||||
long[] targetShape = withMinibatchDim ? new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4};
|
||||
|
||||
for( long minibatch : new long[]{1, 3}) {
|
||||
long[] inArrayShape = new long[]{minibatch, 32};
|
||||
long[] targetArrayShape = new long[]{minibatch, 2, 4, 4};
|
||||
long length = minibatch * 32;
|
||||
|
||||
INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape);
|
||||
|
||||
ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim);
|
||||
|
||||
for( int i=0; i<3; i++ ) {
|
||||
INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces());
|
||||
INDArray expOut = in.reshape(targetArrayShape);
|
||||
assertEquals(expOut, out);
|
||||
|
||||
INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces());
|
||||
assertEquals(in, backprop);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer {
|
|||
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
|
||||
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
|
||||
val inputShape = new long[]{it.getSize()};
|
||||
preprocessor = new ReshapePreprocessor(inputShape, inputShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
|
||||
}
|
||||
return preprocessor;
|
||||
}
|
||||
|
|
|
@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer {
|
|||
} else {
|
||||
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
|
||||
}
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
||||
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
|
||||
if (inputShape[0] != targetShape[0])
|
||||
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
||||
}
|
||||
|
||||
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
|
||||
|
@ -128,23 +128,23 @@ public class KerasReshape extends KerasLayer {
|
|||
} else {
|
||||
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
|
||||
}
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
||||
} else {
|
||||
if (inputShape[0] != targetShape[0])
|
||||
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
|
||||
}
|
||||
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
|
||||
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
|
||||
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
|
||||
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
|
||||
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
|
||||
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
|
||||
val inputShape = new long[]{it.getSize()};
|
||||
if (targetShape.length == 3) {
|
||||
targetShape = targetShapeForDimOrder(inputShape, targetShape);
|
||||
}
|
||||
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
|
||||
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
|
||||
}
|
||||
return preprocessor;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -20,7 +21,6 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
|
||||
|
@ -36,73 +36,72 @@ import java.util.Arrays;
|
|||
import static org.nd4j.linalg.util.ArrayUtil.prodLong;
|
||||
|
||||
/**
|
||||
* Generic reshape preprocessor
|
||||
* Generic reshape preprocessor.
|
||||
* Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension
|
||||
* is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}<br>
|
||||
* For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:<br>
|
||||
* hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR<br>
|
||||
* hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4]
|
||||
*
|
||||
* @author Max Pumperla
|
||||
*/
|
||||
@Data
|
||||
@Slf4j
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
|
||||
@JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"})
|
||||
public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||
|
||||
private long[] inputShape;
|
||||
private long[] targetShape;
|
||||
private boolean hasMiniBatchDimension = false;
|
||||
private int miniBatchSize;
|
||||
private long[] staticTargetShape;
|
||||
private final long[] inputShape;
|
||||
private final long[] targetShape;
|
||||
private boolean hasMiniBatchDimension;
|
||||
|
||||
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
|
||||
this.inputShape = inputShape;
|
||||
this.targetShape = targetShape;
|
||||
/**
|
||||
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
|
||||
*/
|
||||
@Deprecated
|
||||
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
|
||||
this(inputShape, targetShape, false);
|
||||
}
|
||||
|
||||
private static int prod(int[] array) {
|
||||
int prod = 1;
|
||||
for (int i : array) {
|
||||
prod *= i;
|
||||
/**
|
||||
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
||||
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
|
||||
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
|
||||
*/
|
||||
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
|
||||
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
|
||||
this.inputShape = inputShape;
|
||||
this.targetShape = targetShape;
|
||||
this.hasMiniBatchDimension = hasMiniBatchDimension;
|
||||
}
|
||||
|
||||
private long[] getShape(long[] originalShape, long minibatch) {
|
||||
long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch));
|
||||
if (newShape[0] != minibatch) {
|
||||
newShape = newShape.clone();
|
||||
newShape[0] = minibatch;
|
||||
}
|
||||
return prod;
|
||||
return newShape;
|
||||
}
|
||||
|
||||
private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
|
||||
int shapeLength = shape.length;
|
||||
val miniBatchShape = new long[shapeLength + 1];
|
||||
for (int i = 0; i < miniBatchShape.length; i++) {
|
||||
if (i == 0)
|
||||
miniBatchShape[i] = miniBatchSize;
|
||||
else
|
||||
miniBatchShape[i] = shape[i - 1];
|
||||
miniBatchShape[0] = miniBatchSize;
|
||||
for (int i = 1; i < miniBatchShape.length; i++) {
|
||||
miniBatchShape[i] = shape[i - 1];
|
||||
}
|
||||
return miniBatchShape;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
// the target shape read from a keras config does not have mini-batch size
|
||||
// included. We prepend it here dynamically.
|
||||
// the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
|
||||
long[] targetShape = getShape(this.targetShape, miniBatchSize);
|
||||
long[] inputShape = getShape(this.inputShape, miniBatchSize);
|
||||
|
||||
long[] targetShape;
|
||||
if (staticTargetShape != null){
|
||||
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
|
||||
hasMiniBatchDimension = true;
|
||||
this.miniBatchSize = miniBatchSize;
|
||||
}
|
||||
else{
|
||||
targetShape = this.targetShape;
|
||||
}
|
||||
if (!this.hasMiniBatchDimension) {
|
||||
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
|
||||
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
|
||||
this.miniBatchSize = miniBatchSize;
|
||||
}
|
||||
if (this.miniBatchSize != miniBatchSize) {
|
||||
targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize);
|
||||
inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize);
|
||||
this.miniBatchSize = miniBatchSize;
|
||||
}
|
||||
if (prodLong(input.shape()) == prodLong((targetShape))) {
|
||||
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
|
||||
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
|
||||
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
||||
}
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
|
||||
|
@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
|||
|
||||
@Override
|
||||
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
long[] targetShape = getShape(this.targetShape, miniBatchSize);
|
||||
long[] inputShape = getShape(this.inputShape, miniBatchSize);
|
||||
|
||||
if (!Arrays.equals(targetShape, output.shape())) {
|
||||
throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape())
|
||||
+ " (expected to be " + Arrays.toString(targetShape) + ")");
|
||||
}
|
||||
if (prodLong(output.shape()) == prodLong((targetShape))) {
|
||||
if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){
|
||||
if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
|
||||
output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
|
||||
}
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape));
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
|
||||
} else {
|
||||
throw new IllegalStateException("Output shape" + Arrays.toString(output.shape())
|
||||
+ " and input shape" + Arrays.toString(targetShape) + " do not match");
|
||||
|
@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
|||
|
||||
@Override
|
||||
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
|
||||
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
|
||||
long[] shape = getShape(this.targetShape, 0);
|
||||
InputType ret;
|
||||
switch (shape.length) {
|
||||
case 2:
|
||||
|
@ -141,18 +143,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
|||
ret = InputType.recurrent(shape[2], shape[1]);
|
||||
break;
|
||||
case 4:
|
||||
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
|
||||
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
|
||||
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
|
||||
}else {
|
||||
} else {
|
||||
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot infer input type for reshape array " + Arrays.toString(shape));
|
||||
|
||||
}
|
||||
this.staticTargetShape = ret.getShape();
|
||||
return ret;
|
||||
}
|
||||
}
|
|
@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void ReshapeEmbeddingConcatTest() throws Exception{
|
||||
//TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
|
||||
|
||||
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
|
||||
ComputationGraphConfiguration config =
|
||||
new KerasModel().modelBuilder().modelJsonInputStream(is)
|
||||
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
|
||||
ComputationGraph model = new ComputationGraph(config);
|
||||
model.init();
|
||||
// System.out.println(model.summary());
|
||||
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -540,6 +540,8 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.Mish.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class,
|
||||
|
|
|
@ -164,7 +164,7 @@ public class LossMCXENT implements ILossFunction {
|
|||
throw new IllegalStateException("Weights vector (length " + weights.length()
|
||||
+ ") does not match output.size(1)=" + output.size(1));
|
||||
}
|
||||
INDArray temp = labels.mulRowVector(weights);
|
||||
INDArray temp = labels.mulRowVector(weights.castTo(labels.dataType()));
|
||||
INDArray col = temp.sum(true,1);
|
||||
grad = output.mulColumnVector(col).sub(temp);
|
||||
} else {
|
||||
|
|
|
@ -117,7 +117,7 @@ public class LossSparseMCXENT extends LossMCXENT {
|
|||
|
||||
private INDArray toOneHot(INDArray labels, INDArray preOutput){
|
||||
Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
|
||||
"with last dimension having size 1. Got labels array with shape %ndShape", labels);
|
||||
"with first dimension equal to minibatch size, and last dimension having size 1. Got labels array with shape %ndShape", labels);
|
||||
INDArray oneHotLabels = preOutput.ulike();
|
||||
Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
|
||||
return oneHotLabels;
|
||||
|
|
|
@ -1662,7 +1662,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
* This method executes given CustomOp
|
||||
*
|
||||
* PLEASE NOTE: You're responsible for input/output validation
|
||||
* @param op
|
||||
* @param op Operation to execute
|
||||
*/
|
||||
@Override
|
||||
public INDArray[] exec(@NonNull CustomOp op) {
|
||||
|
@ -1671,11 +1671,12 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
try {
|
||||
val list = this.calculateOutputShape(op);
|
||||
if (list.isEmpty())
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to calculate output datatypes");
|
||||
|
||||
for (LongShapeDescriptor shape : list)
|
||||
op.addOutputArgument(Nd4j.create(shape, false));
|
||||
|
||||
} catch (ND4JIllegalStateException e){
|
||||
throw e;
|
||||
} catch (Exception e) {
|
||||
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
|
||||
}
|
||||
|
|
|
@ -68,7 +68,9 @@ public class TestOpMapping extends BaseNd4jTest {
|
|||
}
|
||||
String opName = df.opName();
|
||||
|
||||
assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName));
|
||||
assertTrue("Op is missing - not defined in ImportClassMapping: " + opName +
|
||||
"\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping", opNameMapping.containsKey(opName)
|
||||
);
|
||||
|
||||
try{
|
||||
String[] tfNames = df.tensorflowNames();
|
||||
|
|
|
@ -129,6 +129,13 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"resize_bilinear/int32.*"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
all arrays printed during execution.
|
||||
If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output
|
||||
arrays will be printed during execution
|
||||
*/
|
||||
private final List<String> debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*");
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() {
|
||||
Nd4j.setDataType(DataType.FLOAT);
|
||||
|
@ -194,8 +201,18 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst());
|
||||
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
|
||||
|
||||
boolean verboseDebugMode = false;
|
||||
if(debugModeRegexes != null){
|
||||
for(String regex : debugModeRegexes){
|
||||
if(modelName.matches(regex)){
|
||||
verboseDebugMode = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
|
||||
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode);
|
||||
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
|
||||
} catch (Throwable t){
|
||||
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
|
||||
|
|
|
@ -20,13 +20,15 @@ import org.junit.Test;
|
|||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
|
||||
import org.nd4j.linalg.lossfunctions.impl.*;
|
||||
|
||||
import static junit.framework.TestCase.assertFalse;
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
|
@ -70,6 +72,71 @@ public class LossFunctionTest extends BaseNd4jTest {
|
|||
assertEquals(0, match2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWeightedLossFunctionDTypes(){
|
||||
|
||||
for(DataType activationsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
|
||||
for(DataType weightsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
|
||||
for( boolean rank1W : new boolean[]{false, true}) {
|
||||
|
||||
INDArray preOut = Nd4j.rand(activationsDt, 2, 3);
|
||||
INDArray l = Nd4j.rand(activationsDt, 2, 3);
|
||||
|
||||
INDArray w = Nd4j.createFromArray(1.0f, 2.0f, 3.0f).castTo(weightsDt);
|
||||
if(!rank1W){
|
||||
w = w.reshape(1, 3);
|
||||
}
|
||||
|
||||
ILossFunction lf = null;
|
||||
for (int i = 0; i < 10; i++) {
|
||||
switch (i) {
|
||||
case 0:
|
||||
lf = new LossBinaryXENT(w);
|
||||
break;
|
||||
case 1:
|
||||
lf = new LossL1(w);
|
||||
break;
|
||||
case 2:
|
||||
lf = new LossL2(w);
|
||||
break;
|
||||
case 3:
|
||||
lf = new LossMAE(w);
|
||||
break;
|
||||
case 4:
|
||||
lf = new LossMAPE(w);
|
||||
break;
|
||||
case 5:
|
||||
lf = new LossMCXENT(w);
|
||||
break;
|
||||
case 6:
|
||||
lf = new LossMSE(w);
|
||||
break;
|
||||
case 7:
|
||||
lf = new LossMSLE(w);
|
||||
break;
|
||||
case 8:
|
||||
lf = new LossNegativeLogLikelihood(w);
|
||||
break;
|
||||
case 9:
|
||||
lf = new LossSparseMCXENT(w);
|
||||
l = Nd4j.createFromArray(1,2).reshape(2, 1).castTo(activationsDt);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
//Check score
|
||||
lf.computeScore(l, preOut, new ActivationSoftmax(), null, true);
|
||||
|
||||
//Check backward
|
||||
lf.computeGradient(l, preOut, new ActivationSoftmax(), null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
|
|
Loading…
Reference in New Issue