Various Fixes (#75)

* #8431 Cast loss function weights array automatically

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add 'regex verbose mode' printing (ExecDebugListener) for TFGraphTestAllSameDiff

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Class import mapping fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Reshape fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Don't swallow first exception in NativeOpExecutioner.exec(CustomOp)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Alex Black 2019-11-23 20:06:12 +11:00 committed by GitHub
parent 8d87b078c2
commit e910ce75ec
12 changed files with 189 additions and 67 deletions

TestPreProcessors.java

@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest {
assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn());
}
@Test
public void testPreprocessorVertex(){
for(boolean withMinibatchDim : new boolean[]{true, false}){
long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32};
long[] targetShape = withMinibatchDim ? new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4};
for( long minibatch : new long[]{1, 3}) {
long[] inArrayShape = new long[]{minibatch, 32};
long[] targetArrayShape = new long[]{minibatch, 2, 4, 4};
long length = minibatch * 32;
INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape);
ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim);
for( int i=0; i<3; i++ ) {
INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces());
INDArray expOut = in.reshape(targetArrayShape);
assertEquals(expOut, out);
INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces());
assertEquals(in, backprop);
}
}
}
}
}

KerasFlatten.java

@@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape);
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
}
return preprocessor;
}

KerasReshape.java

@@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
}
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@@ -128,23 +128,23 @@
} else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} else {
if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
}
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape);
}
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
}
return preprocessor;
}

ReshapePreprocessor.java

@@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -20,7 +21,6 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@@ -36,73 +36,72 @@ import java.util.Arrays;
import static org.nd4j.linalg.util.ArrayUtil.prodLong;
/**
* Generic reshape preprocessor
* Generic reshape preprocessor.
* Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension
* is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}<br>
* For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:<br>
* hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR<br>
* hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4]
*
* @author Max Pumperla
*/
@Data
@Slf4j
@EqualsAndHashCode(callSuper = false)
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
@JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor extends BaseInputPreProcessor {
private long[] inputShape;
private long[] targetShape;
private boolean hasMiniBatchDimension = false;
private int miniBatchSize;
private long[] staticTargetShape;
private final long[] inputShape;
private final long[] targetShape;
private boolean hasMiniBatchDimension;
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
this.inputShape = inputShape;
this.targetShape = targetShape;
/**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
private static int prod(int[] array) {
int prod = 1;
for (int i : array) {
prod *= i;
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
this.inputShape = inputShape;
this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension;
}
private long[] getShape(long[] originalShape, long minibatch) {
long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch));
if (newShape[0] != minibatch) {
newShape = newShape.clone();
newShape[0] = minibatch;
}
return prod;
return newShape;
}
private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
int shapeLength = shape.length;
val miniBatchShape = new long[shapeLength + 1];
for (int i = 0; i < miniBatchShape.length; i++) {
if (i == 0)
miniBatchShape[i] = miniBatchSize;
else
miniBatchShape[i] = shape[i - 1];
miniBatchShape[0] = miniBatchSize;
for (int i = 1; i < miniBatchShape.length; i++) {
miniBatchShape[i] = shape[i - 1];
}
return miniBatchShape;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
// the target shape read from a keras config does not have mini-batch size
// included. We prepend it here dynamically.
// the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
long[] targetShape = getShape(this.targetShape, miniBatchSize);
long[] inputShape = getShape(this.inputShape, miniBatchSize);
long[] targetShape;
if (staticTargetShape != null){
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
hasMiniBatchDimension = true;
this.miniBatchSize = miniBatchSize;
}
else{
targetShape = this.targetShape;
}
if (!this.hasMiniBatchDimension) {
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
this.miniBatchSize = miniBatchSize;
}
if (this.miniBatchSize != miniBatchSize) {
targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize);
inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize);
this.miniBatchSize = miniBatchSize;
}
if (prodLong(input.shape()) == prodLong((targetShape))) {
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
@@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
long[] targetShape = getShape(this.targetShape, miniBatchSize);
long[] inputShape = getShape(this.inputShape, miniBatchSize);
if (!Arrays.equals(targetShape, output.shape())) {
throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape())
+ " (expected to be " + Arrays.toString(targetShape) + ")");
}
if (prodLong(output.shape()) == prodLong((targetShape))) {
if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){
if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape));
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
} else {
throw new IllegalStateException("Output shape" + Arrays.toString(output.shape())
+ " and input shape" + Arrays.toString(targetShape) + " do not match");
@@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
long[] shape = getShape(this.targetShape, 0);
InputType ret;
switch (shape.length) {
case 2:
@@ -141,18 +143,16 @@
ret = InputType.recurrent(shape[2], shape[1]);
break;
case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
}else {
} else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
}
break;
default:
throw new UnsupportedOperationException(
"Cannot infer input type for reshape array " + Arrays.toString(shape));
}
this.staticTargetShape = ret.getShape();
return ret;
}
}
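
For reference, a minimal usage sketch of the new constructor (not part of this diff; it mirrors the testPreprocessorVertex test above, and every call it uses appears elsewhere in this commit):

import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReshapePreprocessorSketch {
    public static void main(String[] args) {
        int minibatch = 3;
        INDArray in = Nd4j.linspace(1, minibatch * 32, minibatch * 32).reshape('c', minibatch, 32);
        // The two configurations are equivalent under the new hasMiniBatchDimension flag:
        ReshapePreprocessor withMb = new ReshapePreprocessor(new long[]{-1, 32}, new long[]{-1, 2, 4, 4}, true);
        ReshapePreprocessor withoutMb = new ReshapePreprocessor(new long[]{32}, new long[]{2, 4, 4}, false);
        INDArray a = withMb.preProcess(in, minibatch, LayerWorkspaceMgr.noWorkspaces());
        INDArray b = withoutMb.preProcess(in, minibatch, LayerWorkspaceMgr.noWorkspaces());
        System.out.println(java.util.Arrays.toString(a.shape())); // [3, 2, 4, 4]
        System.out.println(a.equals(b));                          // true
    }
}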

Keras2ModelConfigurationTest.java

@@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
@Test
public void ReshapeEmbeddingConcatTest() throws Exception{
//TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
ComputationGraphConfiguration config =
new KerasModel().modelBuilder().modelJsonInputStream(is)
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
ComputationGraph model = new ComputationGraph(config);
model.init();
// System.out.println(model.summary());
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
}
}

ImportClassMapping.java

@@ -540,6 +540,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Mish.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class,
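
Registering Mish and MishDerivative here is the "class import mapping fix" from the commit message: an op class that is not listed in ImportClassMapping cannot be resolved by name during import (see the TestOpMapping change below). A hedged sketch of invoking the newly mapped activation, assuming Mish follows the (in, out) constructor pattern of the other strict transforms in this package:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish;
import org.nd4j.linalg.factory.Nd4j;

public class MishSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.createFromArray(-3f, -1f, 0f, 1f, 3f);
        INDArray out = x.ulike(); // same shape/dtype as x
        // mish(x) = x * tanh(softplus(x)); the (in, out) constructor is assumed, not shown in this diff
        Nd4j.getExecutioner().exec(new Mish(x, out));
        System.out.println(out);
    }
}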

LossMCXENT.java

@@ -164,7 +164,7 @@ public class LossMCXENT implements ILossFunction {
throw new IllegalStateException("Weights vector (length " + weights.length()
+ ") does not match output.size(1)=" + output.size(1));
}
INDArray temp = labels.mulRowVector(weights);
INDArray temp = labels.mulRowVector(weights.castTo(labels.dataType()));
INDArray col = temp.sum(true,1);
grad = output.mulColumnVector(col).sub(temp);
} else {
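
This one-line change is the #8431 fix from the commit message: a user-supplied weights vector may have a different datatype than the labels and activations (for example, FLOAT weights with a HALF-precision network), and the row-vector multiply requires matching types. A minimal sketch of the situation the cast handles (values illustrative; all calls appear elsewhere in this commit):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class WeightCastSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.rand(DataType.HALF, 2, 3);           // network dtype: HALF
        INDArray weights = Nd4j.createFromArray(1.0f, 2.0f, 3.0f);  // user-supplied: FLOAT
        // Casting the weights to the labels' dtype makes the multiply well-defined:
        INDArray weighted = labels.mulRowVector(weights.castTo(labels.dataType()));
        System.out.println(weighted.dataType()); // HALF
    }
}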

LossSparseMCXENT.java

@@ -117,7 +117,7 @@ public class LossSparseMCXENT extends LossMCXENT {
private INDArray toOneHot(INDArray labels, INDArray preOutput){
Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
"with last dimension having size 1. Got labels array with shape %ndShape", labels);
"with first dimension equal to minibatch size, and last dimension having size 1. Got labels array with shape %ndShape", labels);
INDArray oneHotLabels = preOutput.ulike();
Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
return oneHotLabels;
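
For context, a small sketch of what toOneHot produces; the OneHot usage is copied from the method above (import path assumed), and the values are illustrative:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.factory.Nd4j;

public class SparseLabelsSketch {
    public static void main(String[] args) {
        INDArray labels = Nd4j.createFromArray(1, 0, 2).reshape(3, 1); // class indices, shape [minibatch, 1]
        INDArray preOutput = Nd4j.rand(DataType.FLOAT, 3, 4);          // logits, shape [minibatch, nOut]
        INDArray oneHotLabels = preOutput.ulike();                     // same shape/dtype as preOutput
        Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int) preOutput.size(-1)));
        System.out.println(oneHotLabels); // rows: [0,1,0,0], [1,0,0,0], [0,0,1,0]
    }
}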

NativeOpExecutioner.java

@@ -1662,7 +1662,7 @@
* This method executes given CustomOp
*
* PLEASE NOTE: You're responsible for input/output validation
* @param op
* @param op Operation to execute
*/
@Override
public INDArray[] exec(@NonNull CustomOp op) {
@@ -1671,11 +1671,12 @@
try {
val list = this.calculateOutputShape(op);
if (list.isEmpty())
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to calculate output datatypes");
for (LongShapeDescriptor shape : list)
op.addOutputArgument(Nd4j.create(shape, false));
} catch (ND4JIllegalStateException e){
throw e;
} catch (Exception e) {
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
}
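
The point of the change: a descriptive ND4JIllegalStateException thrown while calculating output shapes (such as the new "failed to calculate output datatypes" message) was previously swallowed and replaced by the generic "outputs not specified" error. A self-contained sketch of the rethrow pattern, with a hypothetical helper and messages standing in for the real calls:

import org.nd4j.linalg.exception.ND4JIllegalStateException;

public class RethrowSketch {
    static void allocateOutputs() {
        // hypothetical stand-in for calculateOutputShape(op) + addOutputArgument(...)
        throw new ND4JIllegalStateException("failed to calculate output datatypes");
    }
    public static void main(String[] args) {
        try {
            allocateOutputs();
        } catch (ND4JIllegalStateException e) {
            throw e; // already descriptive: propagate the first exception unchanged (the fix)
        } catch (Exception e) {
            throw new ND4JIllegalStateException("generic fallback: outputs not specified");
        }
    }
}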

TestOpMapping.java

@@ -68,7 +68,9 @@ public class TestOpMapping extends BaseNd4jTest {
}
String opName = df.opName();
assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName));
assertTrue("Op is missing - not defined in ImportClassMapping: " + opName +
"\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping", opNameMapping.containsKey(opName)
);
try{
String[] tfNames = df.tensorflowNames();

TFGraphTestAllSameDiff.java

@@ -129,6 +129,13 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"resize_bilinear/int32.*"
};
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
all arrays printed during execution.
If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output
arrays will be printed during execution
*/
private final List<String> debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*");
@BeforeClass
public static void beforeClass() {
Nd4j.setDataType(DataType.FLOAT);
@@ -194,8 +201,18 @@
Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst());
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
boolean verboseDebugMode = false;
if(debugModeRegexes != null){
for(String regex : debugModeRegexes){
if(modelName.matches(regex)){
verboseDebugMode = true;
break;
}
}
}
try {
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
} catch (Throwable t){
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
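
A self-contained sketch of the regex gate added above (names illustrative; String.matches performs a full-string regex match):

import java.util.Arrays;
import java.util.List;

public class DebugRegexSketch {
    public static void main(String[] args) {
        List<String> debugModeRegexes = Arrays.asList("resize_nearest_neighbor/.*", "add_n.*");
        String modelName = "add_n_32";
        boolean verboseDebugMode = false;
        for (String regex : debugModeRegexes) {
            if (modelName.matches(regex)) { // full-string match against the test/model name
                verboseDebugMode = true;
                break;
            }
        }
        System.out.println(verboseDebugMode); // true -> all arrays printed during execution
    }
}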

LossFunctionTest.java

@@ -20,13 +20,15 @@ import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.*;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertTrue;
@@ -70,6 +72,71 @@ public class LossFunctionTest extends BaseNd4jTest {
assertEquals(0, match2);
}
@Test
public void testWeightedLossFunctionDTypes(){
for(DataType activationsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
for(DataType weightsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
for( boolean rank1W : new boolean[]{false, true}) {
INDArray preOut = Nd4j.rand(activationsDt, 2, 3);
INDArray l = Nd4j.rand(activationsDt, 2, 3);
INDArray w = Nd4j.createFromArray(1.0f, 2.0f, 3.0f).castTo(weightsDt);
if(!rank1W){
w = w.reshape(1, 3);
}
ILossFunction lf = null;
for (int i = 0; i < 10; i++) {
switch (i) {
case 0:
lf = new LossBinaryXENT(w);
break;
case 1:
lf = new LossL1(w);
break;
case 2:
lf = new LossL2(w);
break;
case 3:
lf = new LossMAE(w);
break;
case 4:
lf = new LossMAPE(w);
break;
case 5:
lf = new LossMCXENT(w);
break;
case 6:
lf = new LossMSE(w);
break;
case 7:
lf = new LossMSLE(w);
break;
case 8:
lf = new LossNegativeLogLikelihood(w);
break;
case 9:
lf = new LossSparseMCXENT(w);
l = Nd4j.createFromArray(1,2).reshape(2, 1).castTo(activationsDt);
break;
default:
throw new RuntimeException();
}
//Check score
lf.computeScore(l, preOut, new ActivationSoftmax(), null, true);
//Check backward
lf.computeGradient(l, preOut, new ActivationSoftmax(), null);
}
}
}
}
}
@Override
public char ordering() {