Fixes for global pooling + masking with different mask datatypes (#212)
* Fixes for global pooling + masking with different mask datatypes
  Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Global pooling backprop dtype fixes
  Signed-off-by: AlexDBlack <blacka101@gmail.com>

Branch: master
Parent: ddf70ac450
Commit: 57d5eb473b

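For context, the new test added below exercises exactly this scenario: a network running in one floating-point type while the feature mask arrives in a different data type. A minimal standalone sketch of that usage, assuming the same DL4J/ND4J calls the test itself uses (the class name and the INT32 mask type here are illustrative choices, not part of the commit):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class MaskDtypeExample {   // hypothetical example class, not part of the commit
    public static void main(String[] args) {
        // FLOAT network: global max pooling over time, then a small dense output layer
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new GlobalPoolingLayer(PoolingType.MAX))
                .layer(new OutputLayer.Builder().nIn(5).nOut(5)
                        .activation(Activation.TANH)
                        .lossFunction(LossFunctions.LossFunction.MSE).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Input [minibatch, vectorSize, timeSteps] in FLOAT; mask [minibatch, timeSteps] in INT32
        INDArray in = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
        INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(DataType.INT32);

        // The masked global-pooling path has to reconcile the mask dtype with the network dtype
        net.output(in, false, mask, null);
    }
}
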
@@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.factory.Nd4j;

@@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest {
            }
        }
    }
+
+    @Test
+    public void testMaskLayerDataTypes(){
+
+        for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE,
+                DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64,
+                DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){
+            INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt);
+
+            for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){
+
+                INDArray in = Nd4j.rand(networkDtype, 2, 5, 10);
+                INDArray label1 = Nd4j.rand(networkDtype, 2, 5);
+                INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10);
+
+                for(PoolingType pt : PoolingType.values()) {
+                    //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt);
+
+                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .list()
+                            .layer(new GlobalPoolingLayer(pt))
+                            .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                    net.init();
+
+                    net.output(in, false, mask, null);
+                    net.output(in, false, mask, null);
+
+
+                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+
+                            .list()
+                            .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                    net2.init();
+
+                    net2.output(in, false, mask, mask);
+                    net2.output(in, false, mask, mask);
+
+                    net.fit(in, label1, mask, null);
+                    net2.fit(in, label2, mask, mask);
+                }
+            }
+        }
+    }
 }

@@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
        INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);

        applyDropOutIfNecessary(training, workspaceMgr);
-        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, LayerWorkspaceMgr.noWorkspaces(), ArrayType.FF_WORKING_MEM);
+        INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input.castTo(W.dataType()), workspaceMgr, ArrayType.FF_WORKING_MEM);

        INDArray act2d = layerConf().getActivationFn().getActivation(input2d.mmul(W).addiRowVector(b), training);
        if (maskArray != null) {

@@ -56,6 +56,7 @@ public class MaskedReductionUtil {
            throw new IllegalArgumentException("Expect rank 2 array for mask: got " + mask.rank());
        }

        toReduce = toReduce.castTo(dataType);
+        mask = mask.castTo(dataType);

        //Sum pooling: easy. Multiply by mask, then sum as normal

@@ -64,13 +65,7 @@ public class MaskedReductionUtil {

        switch (poolingType) {
            case MAX:
                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(dataType);
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.castTo(dataType).rsub(1.0);
                BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

                INDArray withInf = Nd4j.createUninitialized(dataType, toReduce.shape());

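The MAX branch above handles masking with an additive trick rather than a conditional write: build a bias that is 0 for valid time steps and -infinity for masked ones, broadcast-add it onto the activations, then take the max as usual. A small self-contained sketch of that idea, with shapes and names chosen to mirror the code above (an illustration only, not the library method itself; it assumes every row keeps at least one unmasked step):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class MaskedMaxSketch {   // hypothetical example class, not part of the commit
    public static void main(String[] args) {
        INDArray toReduce = Nd4j.rand(DataType.FLOAT, 2, 5, 10);               // [minibatch, vectorSize, timeSteps]
        INDArray mask = Nd4j.ones(DataType.FLOAT, 2, 10);                      // [minibatch, timeSteps], 1 = keep
        mask.get(NDArrayIndex.all(), NDArrayIndex.interval(6, 10)).assign(0);  // mask out the last 4 steps

        INDArray negInfMask = mask.rsub(1.0);                                  // 0 where valid, 1 where masked
        BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

        // Broadcast the [minibatch, timeSteps] bias along dims (0, 2): masked steps become -inf
        INDArray withInf = Nd4j.createUninitialized(DataType.FLOAT, toReduce.shape());
        Nd4j.getExecutioner().exec(new BroadcastAddOp(toReduce, negInfMask, withInf, 0, 2));

        // The max over the time dimension can now never pick a masked step
        INDArray pooled = withInf.max(2);                                      // [minibatch, vectorSize]
        System.out.println(pooled);
    }
}
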
@@ -121,18 +116,14 @@ public class MaskedReductionUtil {
        //Mask: [minibatch, tsLength]
        //Epsilon: [minibatch, vectorSize]

+        mask = mask.castTo(input.dataType());

        switch (poolingType) {
            case MAX:
                //TODO This is ugly - replace it with something better... Need something like a Broadcast CAS op
-                INDArray negInfMask;
-                if(mask.dataType() == DataType.BOOL){
-                    negInfMask = Transforms.not(mask).castTo(Nd4j.defaultFloatingPointType());
-                } else {
-                    negInfMask = mask.rsub(1.0);
-                }
+                INDArray negInfMask = mask.rsub(1.0);
                BooleanIndexing.replaceWhere(negInfMask, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));

-                INDArray withInf = Nd4j.createUninitialized(input.shape());
+                INDArray withInf = Nd4j.createUninitialized(input.dataType(), input.shape());
                Nd4j.getExecutioner().exec(new BroadcastAddOp(input, negInfMask, withInf, 0, 2));
                //At this point: all the masked out steps have value -inf, hence can't be the output of the MAX op

@@ -145,7 +136,7 @@ public class MaskedReductionUtil {
                //if out = avg(in,dims) then dL/dIn = 1/N * dL/dOut
                //With masking: N differs for different time series

-                INDArray out = Nd4j.createUninitialized(input.shape(), 'f');
+                INDArray out = Nd4j.createUninitialized(input.dataType(), input.shape(), 'f');

                //Broadcast copy op, then divide and mask to 0 as appropriate
                Nd4j.getExecutioner().exec(new BroadcastCopyOp(out, epsilon2d, out, 0, 1));

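Restating the comment in the hunk above as a formula: with mask entries m_{i,t} and N_i the number of unmasked steps for example i, the masked average-pooling backprop scales the incoming gradient by 1/N_i and zeroes the masked steps, which is what the broadcast copy, per-example divide, and masking described in the surrounding comments implement.

\[
y_{i,j} = \frac{1}{N_i}\sum_{t} m_{i,t}\,x_{i,j,t},
\qquad
\frac{\partial L}{\partial x_{i,j,t}} = \frac{m_{i,t}}{N_i}\,\frac{\partial L}{\partial y_{i,j}},
\qquad
N_i = \sum_{t} m_{i,t}.
\]
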
@@ -162,7 +153,7 @@ public class MaskedReductionUtil {

            case PNORM:
                //Similar to average and sum pooling: there's no N term here, so we can just set the masked values to 0
-                INDArray masked2 = Nd4j.createUninitialized(input.shape());
+                INDArray masked2 = Nd4j.createUninitialized(input.dataType(), input.shape());
                Nd4j.getExecutioner().exec(new BroadcastMulOp(input, mask, masked2, 0, 2));

                INDArray abs = Transforms.abs(masked2, true);

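Similarly, the PNORM comment above can be written out: because the masked activations are zeroed before the norm is taken, they contribute nothing to the forward value or its gradient, and unlike average pooling there is no per-example count N_i to divide by (a restatement assuming p >= 1):

\[
y_{i,j} = \Big(\sum_{t} m_{i,t}\,\lvert x_{i,j,t}\rvert^{p}\Big)^{1/p},
\qquad
\frac{\partial y_{i,j}}{\partial x_{i,j,t}} = m_{i,t}\,\operatorname{sign}(x_{i,j,t})\,\lvert x_{i,j,t}\rvert^{\,p-1}\,y_{i,j}^{\,1-p}.
\]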