Fixes for global pooling + masking with different mask datatypes (#212)

* Fixes for global pooling + masking with different mask datatypes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Global pooling backprop dtype fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Alex Black, 2020-02-04 15:38:06 +11:00, committed by GitHub
commit 57d5eb473b (parent ddf70ac450)
3 changed files with 61 additions and 22 deletions
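
For context, a minimal standalone sketch of the situation being fixed: supplying a feature mask whose datatype differs from the network's datatype. This is not code from the commit; the class name, shapes, and values are invented for illustration, and it assumes only nd4j on the classpath.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class MaskDtypeRepro {
        public static void main(String[] args) {
            INDArray features = Nd4j.rand(DataType.FLOAT, 2, 5, 10);  // [minibatch, size, tsLength]
            INDArray mask = Nd4j.ones(2, 10).castTo(DataType.INT32);  // integer mask, not FLOAT
            // Pre-fix, broadcasting an INT32 mask against FLOAT activations inside the
            // pooling helpers could fail with a dtype mismatch; the commit's approach is
            // to cast the mask (and, in backprop, the input) to a common dtype first:
            INDArray maskCast = mask.castTo(features.dataType());
            System.out.println(maskCast.dataType());  // FLOAT
        }
    }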

GlobalPoolingMaskingTests.java

@@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
-import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
 import org.nd4j.linalg.factory.Nd4j;
@@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest {
             }
         }
     }
+
+    @Test
+    public void testMaskLayerDataTypes(){
+        for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE,
+                DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64,
+                DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){
+            INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt);
+
+            for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){
+
+                INDArray in = Nd4j.rand(networkDtype, 2, 5, 10);
+                INDArray label1 = Nd4j.rand(networkDtype, 2, 5);
+                INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10);
+
+                for(PoolingType pt : PoolingType.values()) {
+                    //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt);
+
+                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new GlobalPoolingLayer(pt))
+                            .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                    net.init();
+
+                    net.output(in, false, mask, null);
+                    net.output(in, false, mask, null);
+
+                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                    net2.init();
+
+                    net2.output(in, false, mask, mask);
+                    net2.output(in, false, mask, mask);
+
+                    net.fit(in, label1, mask, null);
+                    net2.fit(in, label2, mask, mask);
+                }
+            }
+        }
+    }
 }
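
A note on the mask construction in the test above: Nd4j.rand draws uniformly from [0, 1), so adding 0.3 gives values in [0.3, 1.3), and casting to an integer dtype truncates them to a mixture of 0s and 1s, i.e. a valid binary mask at every tested datatype. A small illustration (hypothetical snippet, not part of the commit):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class MaskCastDemo {
        public static void main(String[] args) {
            // Values in [0.3, 1.3) truncate to 0 or 1 under an integer cast
            INDArray m = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(DataType.INT32);
            System.out.println(m);  // every entry is 0 or 1
        }
    }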

RnnOutputLayer.java

@@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
         INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);

         applyDropOutIfNecessary(training, workspaceMgr);

-        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, LayerWorkspaceMgr.noWorkspaces(), ArrayType.FF_WORKING_MEM);
+        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);
         INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);
         if (maskArray != null) {
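
Two things change above: the reshape now runs under the layer's own workspaceMgr instead of a throwaway LayerWorkspaceMgr.noWorkspaces(), and the input is cast to the weight array's datatype before the matrix multiply. A minimal sketch (invented class name and shapes, not the library code) of why the cast matters:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DtypeSafeMmul {
        public static void main(String[] args) {
            // mmul needs both operands in the same datatype, so a FLOAT16 input
            // against FLOAT weights would fail without the cast
            INDArray W = Nd4j.rand(DataType.FLOAT, 5, 5);
            INDArray input2d = Nd4j.rand(DataType.FLOAT16, 3, 5).castTo(W.dataType());
            System.out.println(input2d.mmul(W));  // [3, 5] pre-activations, FLOAT
        }
    }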

MaskedReductionUtil.java

@@ -56,6 +56,7 @@ public class MaskedReductionUtil {
             throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
         }

+        toReduce = toReduce.castTo(dataType);
        mask = mask.castTo(dataType);

         //Sum pooling: easy. Multiply by mask, then sum as normal

@@ -64,13 +65,7 @@ public class MaskedReductionUtil {

         switch (poolingType) {
             case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(dataType);
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.castTo(dataType).rsub(1.0);
                 BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

                 INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape());

@@ -121,18 +116,14 @@ public class MaskedReductionUtil {
         //Mask: [minibatch, tsLength]
         //Epsilon: [minibatch, vectorSize]

+        mask = mask.castTo(input.dataType());
+
         switch (poolingType) {
             case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType());
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.rsub(1.0);
                 BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

-                INDArray withInf = Nd4j.createUninitialized(input.shape());
+                INDArray withInf = Nd4j.createUninitialized(input.dataType(), input.shape());
                 Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));

                 //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op

@@ -145,7 +136,7 @@ public class MaskedReductionUtil {
                 //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
                 //With masking: N differs for different time series

-                INDArray out = Nd4j.createUninitialized(input.shape(), 'f');
+                INDArray out = Nd4j.createUninitialized(input.dataType(), input.shape(), 'f');

                 //Broadcast copy op, then divide and mask to 0 as appropriate
                 Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1));

@@ -162,7 +153,7 @@ public class MaskedReductionUtil {

             case PNORM:
                 //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0
-                INDArray masked2 = Nd4j.createUninitialized(input.shape());
+                INDArray masked2 = Nd4j.createUninitialized(input.dataType(), input.shape());
                 Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, 0, 2));

                 INDArray abs = Transforms.abs(masked2, true);
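
For reference, the negative-infinity trick that both MAX branches above rely on, as a standalone sketch with toy values (class name and shapes invented; in the real code the addition is a BroadcastAddOp along dimensions 0 and 2 of a 3d input):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.BooleanIndexing;
    import org.nd4j.linalg.indexing.conditions.Conditions;

    public class NegInfMaskDemo {
        public static void main(String[] args) {
            INDArray toReduce = Nd4j.createFromArray(new float[][]{{1, 5, 3}, {2, 4, 9}}); // [minibatch=2, tsLength=3]
            INDArray mask     = Nd4j.createFromArray(new float[][]{{1, 0, 1}, {1, 1, 0}}); // 1 = keep, 0 = masked out
            INDArray negInfMask = mask.rsub(1.0);  // 1 where masked OUT, 0 where kept
            BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
            INDArray withInf = toReduce.add(negInfMask);  // masked steps become -Infinity
            System.out.println(withInf.max(1));  // [3.0, 4.0]: a masked step can never win the max
        }
    }

The backprop version uses the same construction so that, when the epsilon is routed back through the max locations, a masked timestep can never be selected.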