Fixes for global pooling + masking with different mask datatypes (#212)
* Fixes for global pooling + masking with different mask datatypes
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Global pooling backprop dtype fixes
  Signed-off-by: AlexDBlack <blacka101@gmail.com>

Branch: master
Parent: ddf70ac450
Commit: 57d5eb473b
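In short: the mask array passed to output()/fit() no longer has to share the network's data type; the masked global-pooling helpers and the RNN output layer now cast the mask and activations to the appropriate dtype internally. A minimal sketch of the newly supported usage follows, condensed from the test added in this commit (the wrapper class and main method are illustrative only; layer sizes, shapes, and calls mirror that test):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class MaskDtypeExample {
    public static void main(String[] args) {
        // Features: [minibatch, vectorSize, tsLength]; mask: [minibatch, tsLength]
        INDArray in = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
        // Mask dtype (here FLOAT16) deliberately differs from the network/feature dtype (FLOAT)
        INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(DataType.FLOAT16);
        INDArray labels = Nd4j.rand(DataType.FLOAT, 2, 5);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new GlobalPoolingLayer(PoolingType.MAX))
                .layer(new OutputLayer.Builder().nIn(5).nOut(5)
                        .activation(Activation.TANH)
                        .lossFunction(LossFunctions.LossFunction.MSE).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Inference and fitting with a feature mask of a different dtype
        net.output(in, false, mask, null);
        net.fit(in, labels, mask, null);
    }
}

Before the fix, combinations like this (integer and boolean masks in particular) could fail with dtype mismatches inside the masked pooling utilities.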
GlobalPoolingMaskingTests.java:

@@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
-import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.OutputLayer;
-import org.deeplearning4j.nn.conf.layers.PoolingType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
 import org.nd4j.linalg.factory.Nd4j;
@@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest {
             }
         }
     }
+
+    @Test
+    public void testMaskLayerDataTypes(){
+
+        for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE,
+                DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64,
+                DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){
+            INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt);
+
+            for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){
+
+                INDArray in = Nd4j.rand(networkDtype, 2, 5, 10);
+                INDArray label1 = Nd4j.rand(networkDtype, 2, 5);
+                INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10);
+
+                for(PoolingType pt : PoolingType.values()) {
+                    //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt);
+
+                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new GlobalPoolingLayer(pt))
+                            .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                    net.init();
+
+                    net.output(in, false, mask, null);
+                    net.output(in, false, mask, null);
+
+
+                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+
+                            .list()
+                            .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                    net2.init();
+
+                    net2.output(in, false, mask, mask);
+                    net2.output(in, false, mask, mask);
+
+                    net.fit(in, label1, mask, null);
+                    net2.fit(in, label2, mask, mask);
+                }
+            }
+        }
+    }
 }
RnnOutputLayer.java:

@@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
         INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);

         applyDropOutIfNecessary(training, workspaceMgr);
-        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, LayerWorkspaceMgr.noWorkspaces(), ArrayType.FF_WORKING_MEM);
+        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);

         INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);
         if (maskArray != null) {
MaskedReductionUtil.java:

@@ -56,6 +56,7 @@ public class MaskedReductionUtil {
            throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
        }

+        toReduce = toReduce.castTo(dataType);
        mask = mask.castTo(dataType);

        //Sum pooling: easy. Multiply by mask, then sum as normal
@@ -64,13 +65,7 @@ public class MaskedReductionUtil {

        switch (poolingType) {
            case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(dataType);
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.castTo(dataType).rsub(1.0);
                BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

                INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape());
@@ -121,18 +116,14 @@ public class MaskedReductionUtil {
        //Mask: [minibatch, tsLength]
        //Epsilon: [minibatch, vectorSize]

+        mask = mask.castTo(input.dataType());
+
        switch (poolingType) {
            case MAX:
-                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType());
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.rsub(1.0);
                BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

-                INDArray withInf = Nd4j.createUninitialized(input.shape());
+                INDArray withInf = Nd4j.createUninitialized(input.dataType(), input.shape());
                Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));
                //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op
@@ -145,7 +136,7 @@ public class MaskedReductionUtil {
            //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
            //With masking: N differs for different time series

-            INDArray out = Nd4j.createUninitialized(input.shape(), 'f');
+            INDArray out = Nd4j.createUninitialized(input.dataType(), input.shape(), 'f');

            //Broadcast copy op, then divide and mask to 0 as appropriate
            Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1));
@@ -162,7 +153,7 @@ public class MaskedReductionUtil {

            case PNORM:
                //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0
-                INDArray masked2 = Nd4j.createUninitialized(input.shape());
+                INDArray masked2 = Nd4j.createUninitialized(input.dataType(), input.shape());
                Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, 0, 2));

                INDArray abs = Transforms.abs(masked2, true);
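For context, the MAX-pooling branches in MaskedReductionUtil rely on the trick spelled out in the comments above: add negative infinity to every masked-out timestep so that it can never be selected by the max reduction. Below is a standalone sketch of that computation using the same ND4J ops as the utility (the wrapper class and example shapes are illustrative only, not the library's own method):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class MaskedMaxPoolingSketch {
    public static void main(String[] args) {
        // Activations: [minibatch, vectorSize, tsLength]; mask: [minibatch, tsLength], 1 = keep, 0 = masked
        INDArray act = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
        INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).gt(0.25).castTo(act.dataType());

        // negInfMask: 0 for kept steps, -infinity for masked-out steps
        INDArray negInfMask = mask.rsub(1.0);
        BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

        // Broadcast-add along dims (0,2): kept steps are unchanged, masked steps become -infinity
        INDArray withInf = Nd4j.createUninitialized(act.dataType(), act.shape());
        Nd4j.getExecutioner().exec(new BroadcastAddOp(act, negInfMask, withInf, 0, 2));

        // Masked steps can never win the max over the time dimension
        INDArray pooled = withInf.max(2);   // [minibatch, vectorSize]
        System.out.println(pooled);
    }
}

Because kept steps only ever have 0 added to them, the broadcast add leaves their values untouched, so no conditional copy or gather is needed before the reduction.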