Various Fixes (#75)
* #8431 Cast loss function weights array automatically
* Add 'regex verbose mode' printing (ExecDebugListener) for TFGraphTestAllSameDiff
* Class import mapping fix
* Reshape fixes
* Don't swallow first exception in NativeOpExecutioner.exec(CustomOp)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
parent 8d87b078c2
commit e910ce75ec
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
 import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
 import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
 import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
+import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest {
 
         assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn());
     }
+
+
+    @Test
+    public void testPreprocessorVertex(){
+        for(boolean withMinibatchDim : new boolean[]{true, false}){
+            long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32};
+            long[] targetShape = withMinibatchDim ? new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4};
+
+            for( long minibatch : new long[]{1, 3}) {
+                long[] inArrayShape = new long[]{minibatch, 32};
+                long[] targetArrayShape = new long[]{minibatch, 2, 4, 4};
+                long length = minibatch * 32;
+
+                INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape);
+
+                ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim);
+
+                for( int i=0; i<3; i++ ) {
+                    INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces());
+                    INDArray expOut = in.reshape(targetArrayShape);
+                    assertEquals(expOut, out);
+
+                    INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces());
+                    assertEquals(in, backprop);
+                }
+            }
+        }
+    }
 }
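Note: the new test exercises both constructor modes. Per the javadoc added to ReshapePreprocessor further down in this commit, the two are equivalent ways of describing the same reshape; a minimal sketch, with shapes taken from the test above:

    // Shapes include the minibatch dimension; -1 stands in for "any minibatch size"...
    ReshapePreprocessor a = new ReshapePreprocessor(new long[]{-1, 32}, new long[]{-1, 2, 4, 4}, true);
    // ...or omit it entirely and let the preprocessor prepend it on each call
    ReshapePreprocessor b = new ReshapePreprocessor(new long[]{32}, new long[]{2, 4, 4}, false);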
@@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer {
                 // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
                 InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
                 val inputShape = new long[]{it.getSize()};
-                preprocessor = new ReshapePreprocessor(inputShape, inputShape);
+                preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
             }
         return preprocessor;
     }
@@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer {
                 } else {
                     targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
                 }
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
                 if (inputShape[0] != targetShape[0])
                     targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             }
 
         } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@@ -128,23 +128,23 @@ public class KerasReshape extends KerasLayer {
                 } else {
                     targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
                 }
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             } else {
                 if (inputShape[0] != targetShape[0])
                     targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
             }
         } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
             InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
             val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
-            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
+            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
         } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
             InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
             val inputShape = new long[]{it.getSize()};
             if (targetShape.length == 3) {
                 targetShape = targetShapeForDimOrder(inputShape, targetShape);
             }
-            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
+            preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
         }
         return preprocessor;
     }
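Note: every call site in the Keras import code passes hasMiniBatchDimension = false, since Keras configs describe shapes per example, without a batch dimension. A hypothetical Flatten of a 28x28x1 input, purely for illustration (the shapes here are invented for the example):

    // Keras-style shapes: [28, 28, 1] -> [784], no batch dimension on either side
    preprocessor = new ReshapePreprocessor(new long[]{28, 28, 1}, new long[]{784}, false);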
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -20,7 +21,6 @@ import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
-import org.apache.commons.lang3.ArrayUtils;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
 import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@@ -36,42 +36,59 @@ import java.util.Arrays;
 import static org.nd4j.linalg.util.ArrayUtil.prodLong;
 
 /**
- * Generic reshape preprocessor
+ * Generic reshape preprocessor.
+ * Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension
+ * is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}<br>
+ * For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:<br>
+ * hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR<br>
+ * hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4]
  *
  * @author Max Pumperla
  */
 @Data
 @Slf4j
 @EqualsAndHashCode(callSuper = false)
-@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
+@JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"})
 public class ReshapePreprocessor extends BaseInputPreProcessor {
 
-    private long[] inputShape;
-    private long[] targetShape;
-    private boolean hasMiniBatchDimension = false;
-    private int miniBatchSize;
-    private long[] staticTargetShape;
+    private final long[] inputShape;
+    private final long[] targetShape;
+    private boolean hasMiniBatchDimension;
 
-    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
+    /**
+     * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
+     */
+    @Deprecated
+    public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
+        this(inputShape, targetShape, false);
+    }
+
+    /**
+     * @param inputShape            Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+     * @param targetShape           Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+     * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
+     */
+    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
+                               @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
         this.inputShape = inputShape;
         this.targetShape = targetShape;
+        this.hasMiniBatchDimension = hasMiniBatchDimension;
     }
 
-    private static int prod(int[] array) {
-        int prod = 1;
-        for (int i : array) {
-            prod *= i;
+    private long[] getShape(long[] originalShape, long minibatch) {
+        long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch));
+        if (newShape[0] != minibatch) {
+            newShape = newShape.clone();
+            newShape[0] = minibatch;
         }
-        return prod;
+        return newShape;
     }
 
     private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
         int shapeLength = shape.length;
         val miniBatchShape = new long[shapeLength + 1];
-        for (int i = 0; i < miniBatchShape.length; i++) {
-            if (i == 0)
-                miniBatchShape[i] = miniBatchSize;
-            else
+        miniBatchShape[0] = miniBatchSize;
+        for (int i = 1; i < miniBatchShape.length; i++) {
             miniBatchShape[i] = shape[i - 1];
         }
         return miniBatchShape;
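Note: the effect of the new getShape helper is easiest to see with concrete values. A standalone sketch of the same logic in plain Java (no ND4J needed), assuming hasMiniBatchDimension = false:

    // prependMiniBatchSize([2, 4, 4], 3) -> [3, 2, 4, 4]
    long[] shape = {2, 4, 4};
    long minibatch = 3;
    long[] withMb = new long[shape.length + 1];
    withMb[0] = minibatch;                    // minibatch size always lands at index 0
    for (int i = 1; i < withMb.length; i++) {
        withMb[i] = shape[i - 1];             // remaining dimensions shift right by one
    }
    // getShape additionally overwrites index 0 (e.g. a stored -1 placeholder)
    // with the actual minibatch size of the current call
    System.out.println(java.util.Arrays.toString(withMb)); // [3, 2, 4, 4]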
@@ -79,30 +96,12 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
 
     @Override
     public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
-        // the target shape read from a keras config does not have mini-batch size
-        // included. We prepend it here dynamically.
+        // the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
+        long[] targetShape = getShape(this.targetShape, miniBatchSize);
+        long[] inputShape = getShape(this.inputShape, miniBatchSize);
 
-        long[] targetShape;
-        if (staticTargetShape != null){
-            targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
-            hasMiniBatchDimension = true;
-            this.miniBatchSize = miniBatchSize;
-        }
-        else{
-            targetShape = this.targetShape;
-        }
-        if (!this.hasMiniBatchDimension) {
-            targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
-            inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
-            this.miniBatchSize = miniBatchSize;
-        }
-        if (this.miniBatchSize != miniBatchSize) {
-            targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize);
-            inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize);
-            this.miniBatchSize = miniBatchSize;
-        }
         if (prodLong(input.shape()) == prodLong((targetShape))) {
-            if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
+            if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
                 input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
             }
             return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
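Note: the deleted branches mutated hasMiniBatchDimension, miniBatchSize and the inputShape field on the fly, so a preprocessor reused across different minibatch sizes could compute from stale state; the replacement derives both shapes locally on every call. The surviving guard matters because reshape needs contiguous 'c'-ordered data; a sketch of that pattern, using only ND4J calls that appear in this diff plus dup:

    INDArray x = Nd4j.linspace(1, 12, 12).reshape(3, 4).transpose(); // non-default strides
    if (x.ordering() != 'c' || !Shape.hasDefaultStridesForShape(x)) {
        x = x.dup('c');           // copy into a contiguous 'c'-order buffer first
    }
    INDArray y = x.reshape(2, 6); // now safe to reshape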
@@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
 
     @Override
     public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+        long[] targetShape = getShape(this.targetShape, miniBatchSize);
+        long[] inputShape = getShape(this.inputShape, miniBatchSize);
+
         if (!Arrays.equals(targetShape, output.shape())) {
             throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape())
                     + " (expected to be " + Arrays.toString(targetShape) + ")");
         }
         if (prodLong(output.shape()) == prodLong((targetShape))) {
-            if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){
+            if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
                 output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
             }
-            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape));
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
         } else {
             throw new IllegalStateException("Output shape" + Arrays.toString(output.shape())
                     + " and input shape" + Arrays.toString(targetShape) + " do not match");
@@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
 
     @Override
     public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
-        val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
+        long[] shape = getShape(this.targetShape, 0);
         InputType ret;
         switch (shape.length) {
             case 2:
@@ -141,18 +143,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
                 ret = InputType.recurrent(shape[2], shape[1]);
                 break;
             case 4:
-                if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
+                if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                     ret = InputType.convolutional(shape[1], shape[2], shape[3]);
-                }else {
+                } else {
                     ret = InputType.convolutional(shape[2], shape[3], shape[1]);
                 }
                 break;
             default:
                 throw new UnsupportedOperationException(
                         "Cannot infer input type for reshape array " + Arrays.toString(shape));
 
         }
-        this.staticTargetShape = ret.getShape();
         return ret;
     }
 }
@@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
 
     @Test
    public void ReshapeEmbeddingConcatTest() throws Exception{
+        //TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
+
         try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
             ComputationGraphConfiguration config =
                     new KerasModel().modelBuilder().modelJsonInputStream(is)
                             .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
             ComputationGraph model = new ComputationGraph(config);
             model.init();
+//            System.out.println(model.summary());
             model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
         }
     }
@@ -540,6 +540,8 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class,
+            org.nd4j.linalg.api.ops.impl.transforms.strict.Mish.class,
+            org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class,
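Note: Mish (Misra, 2019) is x * tanh(softplus(x)); registering the op class (and its derivative) here is what makes it resolvable by name during import. For reference, a plain-Java sketch of the scalar function the op computes:

    static double mish(double x) {
        double softplus = Math.log1p(Math.exp(x)); // ln(1 + e^x)
        return x * Math.tanh(softplus);
    }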
@@ -164,7 +164,7 @@ public class LossMCXENT implements ILossFunction {
                     throw new IllegalStateException("Weights vector (length " + weights.length()
                             + ") does not match output.size(1)=" + output.size(1));
                 }
-                INDArray temp = labels.mulRowVector(weights);
+                INDArray temp = labels.mulRowVector(weights.castTo(labels.dataType()));
                 INDArray col = temp.sum(true,1);
                 grad = output.mulColumnVector(col).sub(temp);
             } else {
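Note: this is the #8431 fix from the commit message. ND4J pairwise ops require matching data types, so FLOAT labels combined with, say, a DOUBLE weights array previously failed here. A minimal sketch of the now-working combination:

    INDArray labels = Nd4j.rand(DataType.FLOAT, 2, 3);
    INDArray weights = Nd4j.createFromArray(1.0, 2.0, 3.0); // DOUBLE by default
    INDArray weighted = labels.mulRowVector(weights.castTo(labels.dataType())); // cast makes dtypes agree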
@@ -117,7 +117,7 @@ public class LossSparseMCXENT extends LossMCXENT {
 
     private INDArray toOneHot(INDArray labels, INDArray preOutput){
         Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
-                "with last dimension having size 1. Got labels array with shape %ndShape", labels);
+                "with first dimension equal to minibatch size, and last dimension having size 1. Got labels array with shape %ndShape", labels);
         INDArray oneHotLabels = preOutput.ulike();
         Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
         return oneHotLabels;
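Note: concretely, for minibatch size 2 and 3 classes, a sketch of what toOneHot receives and what the OneHot op expands it to:

    // class-index labels, shape [minibatch, 1]
    INDArray labels = Nd4j.createFromArray(2, 0).reshape(2, 1);
    // one-hot result, shape [minibatch, numClasses]:
    // [[0, 0, 1],
    //  [1, 0, 0]]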
@@ -1662,7 +1662,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
      * This method executes given CustomOp
      *
      * PLEASE NOTE: You're responsible for input/output validation
-     * @param op
+     * @param op Operation to execute
      */
     @Override
     public INDArray[] exec(@NonNull CustomOp op) {
@@ -1671,11 +1671,12 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
             try {
                 val list = this.calculateOutputShape(op);
                 if (list.isEmpty())
-                    throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
+                    throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to calculate output datatypes");
 
                 for (LongShapeDescriptor shape : list)
                     op.addOutputArgument(Nd4j.create(shape, false));
+            } catch (ND4JIllegalStateException e){
+                throw e;
             } catch (Exception e) {
                 throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
             }
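Note: the ordering of the catch blocks is the whole fix: catching the subclass first lets the informative shape-calculation error propagate instead of being replaced by the generic message. The general pattern, as a plain-Java sketch with hypothetical names:

    try {
        doWork();                        // hypothetical call that may throw
    } catch (IllegalStateException e) {  // the specific, informative failure
        throw e;                         // rethrow as-is; don't swallow its message
    } catch (Exception e) {              // anything else gets the generic wrapper
        throw new IllegalStateException("operation failed", e);
    }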
@@ -68,7 +68,9 @@ public class TestOpMapping extends BaseNd4jTest {
             }
             String opName = df.opName();
 
-            assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName));
+            assertTrue("Op is missing - not defined in ImportClassMapping: " + opName +
+                    "\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping", opNameMapping.containsKey(opName)
+            );
 
             try{
                 String[] tfNames = df.tensorflowNames();
@@ -129,6 +129,13 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
             "resize_bilinear/int32.*"
     };
 
+    /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
+       all arrays printed during execution.
+       If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output
+       arrays will be printed during execution
+     */
+    private final List<String> debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*");
+
     @BeforeClass
     public static void beforeClass() {
         Nd4j.setDataType(DataType.FLOAT);
@@ -194,8 +201,18 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
         Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst());
         Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
 
+        boolean verboseDebugMode = false;
+        if(debugModeRegexes != null){
+            for(String regex : debugModeRegexes){
+                if(modelName.matches(regex)){
+                    verboseDebugMode = true;
+                    break;
+                }
+            }
+        }
+
         try {
-            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
+            TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode);
             //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
         } catch (Throwable t){
             log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
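Note: to enable the new verbose mode, the null field above is swapped for a regex list, exactly as its trailing comment suggests:

    private final List<String> debugModeRegexes = Arrays.asList("resize_nearest_neighbor/.*", "add_n.*");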
@@ -20,13 +20,15 @@ import org.junit.Test;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.activations.impl.ActivationSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
 import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.indexing.conditions.Conditions;
-import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
+import org.nd4j.linalg.lossfunctions.impl.*;
 
 import static junit.framework.TestCase.assertFalse;
 import static junit.framework.TestCase.assertTrue;
@@ -70,6 +72,71 @@ public class LossFunctionTest extends BaseNd4jTest {
         assertEquals(0, match2);
     }
 
+    @Test
+    public void testWeightedLossFunctionDTypes(){
+
+        for(DataType activationsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
+            for(DataType weightsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
+                for( boolean rank1W : new boolean[]{false, true}) {
+
+                    INDArray preOut = Nd4j.rand(activationsDt, 2, 3);
+                    INDArray l = Nd4j.rand(activationsDt, 2, 3);
+
+                    INDArray w = Nd4j.createFromArray(1.0f, 2.0f, 3.0f).castTo(weightsDt);
+                    if(!rank1W){
+                        w = w.reshape(1, 3);
+                    }
+
+                    ILossFunction lf = null;
+                    for (int i = 0; i < 10; i++) {
+                        switch (i) {
+                            case 0:
+                                lf = new LossBinaryXENT(w);
+                                break;
+                            case 1:
+                                lf = new LossL1(w);
+                                break;
+                            case 2:
+                                lf = new LossL2(w);
+                                break;
+                            case 3:
+                                lf = new LossMAE(w);
+                                break;
+                            case 4:
+                                lf = new LossMAPE(w);
+                                break;
+                            case 5:
+                                lf = new LossMCXENT(w);
+                                break;
+                            case 6:
+                                lf = new LossMSE(w);
+                                break;
+                            case 7:
+                                lf = new LossMSLE(w);
+                                break;
+                            case 8:
+                                lf = new LossNegativeLogLikelihood(w);
+                                break;
+                            case 9:
+                                lf = new LossSparseMCXENT(w);
+                                l = Nd4j.createFromArray(1,2).reshape(2, 1).castTo(activationsDt);
+                                break;
+                            default:
+                                throw new RuntimeException();
+                        }
+                    }
+
+                    //Check score
+                    lf.computeScore(l, preOut, new ActivationSoftmax(), null, true);
+
+                    //Check backward
+                    lf.computeGradient(l, preOut, new ActivationSoftmax(), null);
+                }
+            }
+        }
+
+    }
 
 
     @Override
     public char ordering() {