diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java index 957d22a08..aebf23673 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java @@ -21,14 +21,13 @@ import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; -import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.conf.layers.GravesLSTM; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.conf.layers.PoolingType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; import org.nd4j.linalg.factory.Nd4j; @@ -416,4 +415,53 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { } } } + + @Test + public void testMaskLayerDataTypes(){ + + for(DataType dt : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE, + DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64, + DataType.UINT8, DataType.UINT16, DataType.UINT32, DataType.UINT64}){ + INDArray mask = Nd4j.rand(DataType.FLOAT, 2, 10).addi(0.3).castTo(dt); + + for(DataType networkDtype : new DataType[]{DataType.FLOAT16, DataType.BFLOAT16, DataType.FLOAT, DataType.DOUBLE}){ + + INDArray in = Nd4j.rand(networkDtype, 2, 5, 10); + INDArray label1 = Nd4j.rand(networkDtype, 2, 5); + INDArray label2 = Nd4j.rand(networkDtype, 2, 5, 10); + + for(PoolingType pt : PoolingType.values()) { + //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .list() + .layer(new GlobalPoolingLayer(pt)) + .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.output(in, false, mask, null); + net.output(in, false, mask, null); + + + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + + .list() + .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) + .build(); + + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); + + net2.output(in, false, mask, mask); + net2.output(in, false, mask, mask); + + net.fit(in, label1, mask, null); + net2.fit(in, label2, mask, mask); + } + } + } + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index 27cc0f9df..93f77ad67 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -131,7 +131,7 @@ public class RnnOutputLayer extends BaseOutputLayer