From bd376ca9939eb5bdc60e57d8608b06a40f48cdc1 Mon Sep 17 00:00:00 2001
From: Andrii T <39699084+atuzhykov@users.noreply.github.com>
Date: Fri, 24 Apr 2020 18:12:46 +0300
Subject: [PATCH] GRU and GRUBp (#410)

* GRU and GRUBp ops added and tested

Signed-off-by: Andrii Tuzhykov

* minor polishing

Signed-off-by: Andrii Tuzhykov

* few requested changes

Signed-off-by: Andrii Tuzhykov

* regenerated namespace + small fix in RnnOpValidation

Signed-off-by: Andrii Tuzhykov

* Fix bad character in RnnOpValidation

Signed-off-by: Alex Black

Co-authored-by: Alex Black
---
 .../org/nd4j/autodiff/samediff/ops/SDRNN.java | 55 +++++++++++++--
 .../converters/ImportClassMapping.java | 2 +
 .../api/ops/impl/layers/recurrent/GRU.java | 94 +++++++++++++++++++++++++++
 .../api/ops/impl/layers/recurrent/GRUBp.java | 71 ++++++++++++++++++++
 .../org/nd4j/linalg/factory/ops/NDRNN.java | 25 ++++++-
 .../opvalidation/LayerOpValidation.java | 30 +++++++++
 .../opvalidation/RnnOpValidation.java | 2 +-
 7 files changed, 269 insertions(+), 10 deletions(-)
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java
index de8148c02..ebb1a025d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java
@@ -35,6 +35,48 @@ public class SDRNN extends SDOps {
     super(sameDiff);
   }
 
+  /**
+   * The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
+   *
+   * @param x input [time, bS, nIn] (NUMERIC type)
+   * @param hLast initial cell output (at time step = 0) [bS, nOut] (NUMERIC type)
+   * @param Wx input-to-hidden weights, [nIn, 3*nOut] (NUMERIC type)
+   * @param Wh hidden-to-hidden weights, [nOut, 3*nOut] (NUMERIC type)
+   * @param biases biases, [3*nOut] (NUMERIC type)
+   * @return h cell outputs [time, bS, nOut], that is per each time step (NUMERIC type)
+   */
+  public SDVariable gru(SDVariable x, SDVariable hLast, SDVariable Wx, SDVariable Wh,
+      SDVariable biases) {
+    SDValidation.validateNumerical("gru", "x", x);
+    SDValidation.validateNumerical("gru", "hLast", hLast);
+    SDValidation.validateNumerical("gru", "Wx", Wx);
+    SDValidation.validateNumerical("gru", "Wh", Wh);
+    SDValidation.validateNumerical("gru", "biases", biases);
+    return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU(sd,x, hLast, Wx, Wh, biases).outputVariable();
+  }
+
+  /**
+   * The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x input [time, bS, nIn] (NUMERIC type)
+   * @param hLast initial cell output (at time step = 0) [bS, nOut] (NUMERIC type)
+   * @param Wx input-to-hidden weights, [nIn, 3*nOut] (NUMERIC type)
+   * @param Wh hidden-to-hidden weights, [nOut, 3*nOut] (NUMERIC type)
+   * @param biases biases, [3*nOut] (NUMERIC type)
+   * @return h cell outputs [time, bS, nOut], that is per each time step (NUMERIC type)
+   */
+  public SDVariable gru(String name, SDVariable x, SDVariable hLast, SDVariable Wx, SDVariable Wh,
+      SDVariable biases) {
+    SDValidation.validateNumerical("gru", "x", x);
+    SDValidation.validateNumerical("gru", "hLast", hLast);
+    SDValidation.validateNumerical("gru", "Wx", Wx);
+    SDValidation.validateNumerical("gru", "Wh", Wh);
+    SDValidation.validateNumerical("gru", "biases", biases);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU(sd,x, hLast, Wx, Wh, biases).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
   /**
    * The GRU cell. Does a single time step operation
    *
@@ -42,9 +84,9 @@ public class SDRNN extends SDOps {
    * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
    * @param GRUWeights Configuration Object
    */
-  public SDVariable[] gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
-    SDValidation.validateNumerical("gru", "x", x);
-    SDValidation.validateNumerical("gru", "hLast", hLast);
+  public SDVariable[] gruCell(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
+    SDValidation.validateNumerical("gruCell", "x", x);
+    SDValidation.validateNumerical("gruCell", "hLast", hLast);
     return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables();
   }
 
@@ -56,9 +98,10 @@ public class SDRNN extends SDOps {
    * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
    * @param GRUWeights Configuration Object
    */
-  public SDVariable[] gru(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
-    SDValidation.validateNumerical("gru", "x", x);
-    SDValidation.validateNumerical("gru", "hLast", hLast);
+  public SDVariable[] gruCell(String[] names, SDVariable x, SDVariable hLast,
+      GRUWeights GRUWeights) {
+    SDValidation.validateNumerical("gruCell", "x", x);
+    SDValidation.validateNumerical("gruCell", "hLast", hLast);
     SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables();
     return sd.updateVariableNamesAndReferences(out, names);
   }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 6af2d462a..62e55a02d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -144,6 +144,8 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth.class,
             org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d.class,
             org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative.class,
+            org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU.class,
+            org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUBp.class,
             org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell.class,
             org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class,
             org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java
new file mode 100644
index 000000000..0cc62833d
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java
@@ -0,0 +1,94 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * GRU (Gated Recurrent Unit, Cho et al. 2014) operation, identified by the op name "gru".
+ * Inputs, in order: x (input), hI (initial cell output), Wx (input-to-hidden weights),
+ * Wh (hidden-to-hidden weights) and biases. The op declares a single output.
+ */
+@NoArgsConstructor
+public class GRU extends DynamicCustomOp {
+
+    /**
+     * Creates the op inside a SameDiff graph.
+     *
+     * @param sameDiff graph the op is added to
+     * @param x        input
+     * @param hI       initial cell output
+     * @param Wx       input-to-hidden weights
+     * @param Wh       hidden-to-hidden weights
+     * @param biases   biases
+     */
+    public GRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable hI, @NonNull SDVariable Wx, @NonNull SDVariable Wh, @NonNull SDVariable biases) {
+        super(null, sameDiff, new SDVariable[]{x, hI, Wx, Wh, biases});
+    }
+
+    /**
+     * Creates the op for direct execution on arrays; the inputs are the same five as for the SameDiff constructor.
+     */
+    public GRU(@NonNull INDArray x, @NonNull INDArray hI, @NonNull INDArray Wx, @NonNull INDArray Wh, @NonNull INDArray biases) {
+        super(new INDArray[]{x, hI, Wx, Wh, biases}, null);
+    }
+
+    /**
+     * Validates the input data types and derives the output data type.
+     *
+     * @param inputDataTypes data types of the 5 inputs; each must be a floating point type
+     * @return single-element list holding the data type of input 1 (the initial cell output)
+     */
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 5, "Expected 5 inputs to GRU: input, initial cell output, input-to-hidden weights, hidden-to-hidden weights and biases, got %s", inputDataTypes);
+        for (int i = 0; i < inputDataTypes.size(); i++) {
+            // Report the type of the input that failed the check.
+            Preconditions.checkState(inputDataTypes.get(i).isFPType(), "All input types must be a floating point type, got %s", inputDataTypes.get(i));
+        }
+        // Input 1 is covered by the loop above, so it needs no separate floating point check.
+        return Collections.singletonList(inputDataTypes.get(1));
+    }
+
+    /**
+     * Builds the backprop op from the five forward inputs plus the first incoming gradient.
+     *
+     * @param grads incoming gradients; only element 0 is used
+     * @return output variables of the created {@link GRUBp} op
+     */
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grads) {
+        return Arrays.asList(new GRUBp(sameDiff, arg(0), arg(1), arg(2), arg(3),
+                arg(4), grads.get(0)).outputVariables());
+    }
+
+    /** @return the op name, "gru" */
+    @Override
+    public String opName() {
+        return "gru";
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java
new file mode 100644
index 000000000..b667fa811
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java
@@ -0,0 +1,71 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.impl.layers.recurrent;
+
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Backprop op for {@link GRU}, identified by the op name "gru_bp".
+ * Inputs, in order: x, hI, Wx, Wh, biases (the GRU inputs) and dLdh (the gradient passed in for the GRU output).
+ * The op declares five outputs, which {@link GRU#doDiff} returns as its result.
+ */
+@NoArgsConstructor
+public class GRUBp extends DynamicCustomOp {
+
+    /**
+     * Creates the backprop op inside a SameDiff graph.
+     *
+     * @param sameDiff graph the op is added to
+     * @param x        input
+     * @param hI       initial cell output
+     * @param Wx       input-to-hidden weights
+     * @param Wh       hidden-to-hidden weights
+     * @param biases   biases
+     * @param dLdh     gradient passed in for the GRU output
+     */
+    public GRUBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable hI, @NonNull SDVariable Wx, @NonNull SDVariable Wh, @NonNull SDVariable biases, @NonNull SDVariable dLdh) {
+        super(null, sameDiff, new SDVariable[]{x, hI, Wx, Wh, biases, dLdh});
+    }
+
+    /**
+     * Derives the data types of the five outputs.
+     *
+     * @param inputDataTypes data types of the 6 inputs
+     * @return mutable list of five entries, each the data type of input 1 (hI)
+     */
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 6, "Expected 6 inputs to GRUBp: input, initial cell output, input-to-hidden weights, hidden-to-hidden weights, biases and output gradient, got %s", inputDataTypes);
+        // nCopies is immutable; wrap it so callers still receive a modifiable list.
+        return new ArrayList<>(Collections.nCopies(5, inputDataTypes.get(1)));
+    }
+
+    /** @return the op name, "gru_bp" */
+    @Override
+    public String opName() {
+        return "gru_bp";
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java
index 9bb7d9640..6dee1ef7e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java
@@ -34,6 +34,25 @@ public class NDRNN {
   public NDRNN() {
   }
 
+  /**
+   * The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
+   *
+   * @param x input [time, bS, nIn] (NUMERIC type)
+   * @param hLast initial cell output (at time step = 0) [bS, nOut] (NUMERIC type)
+   * @param Wx input-to-hidden weights, [nIn, 3*nOut] (NUMERIC type)
+   * @param Wh hidden-to-hidden weights, [nOut, 3*nOut] (NUMERIC type)
+   * @param biases biases, [3*nOut] (NUMERIC type)
+   * @return h cell outputs [time, bS, nOut], that is per each time step (NUMERIC type)
+   */
+  public INDArray gru(INDArray x, INDArray hLast, INDArray Wx, INDArray Wh, INDArray biases) {
+    NDValidation.validateNumerical("gru", "x", x);
+    NDValidation.validateNumerical("gru", "hLast", hLast);
+    NDValidation.validateNumerical("gru", "Wx", Wx);
+    NDValidation.validateNumerical("gru", "Wh", Wh);
+    NDValidation.validateNumerical("gru", "biases", biases);
+    return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU(x, hLast, Wx, Wh, biases))[0];
+  }
+
   /**
    * The GRU cell. Does a single time step operation
    *
@@ -41,9 +60,9 @@ public class NDRNN {
    * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
    * @param GRUWeights Configuration Object
    */
-  public INDArray[] gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) {
-    NDValidation.validateNumerical("gru", "x", x);
-    NDValidation.validateNumerical("gru", "hLast", hLast);
+  public INDArray[] gruCell(INDArray x, INDArray hLast, GRUWeights GRUWeights) {
+    NDValidation.validateNumerical("gruCell", "x", x);
+    NDValidation.validateNumerical("gruCell", "hLast", hLast);
     return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights));
   }
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index c83a55d08..964c82c18 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -38,6 +38,7 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
+import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
 import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
@@ -1701,6 +1702,35 @@ public class LayerOpValidation extends BaseOpValidation {
         assertNull(err);
     }
 
+    @Test
+    public void GRUTestCase() {
+        int bS = 5;
+        int nIn = 4;
+        int nOut = 6;
+        int time = 2;
+
+        SameDiff sd = SameDiff.create();
+
+        SDVariable in = sd.var("in", Nd4j.randn(DataType.DOUBLE, time, bS, nIn).muli(10));
+        SDVariable hLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, nOut));
+        SDVariable Wx = sd.var("Wx", Nd4j.randn(DataType.DOUBLE, nIn, 3*nOut));
+        SDVariable Wh = sd.var("Wh", Nd4j.randn(DataType.DOUBLE, nOut, 3*nOut));
+        SDVariable biases = sd.var("bias", Nd4j.randn(DataType.DOUBLE, 3*nOut));
+
+        SDVariable out = new GRU(sd, in, hLast, Wx, Wh,biases).outputVariable();
+
+        long[] outShapes = new long[]{time,bS, nOut};
+        assertArrayEquals(outShapes, out.eval().shape());
+
+        sd.setLossVariables(out.std(true));
+        String err = OpValidation.validate(new TestCase(sd)
+                .gradientCheck(true)
+        );
+
+        assertNull(err);
+
+    }
+
 
 
 
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java
index c47d02b04..f3c79db65 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java
@@ -227,7 +227,7 @@ public class RnnOpValidation extends BaseOpValidation {
                 .cBias(bc)
                 .build();
 
-        SDVariable[] v = sd.rnn().gru(x, hLast, weights);
+        SDVariable[] v = sd.rnn().gruCell(x, hLast, weights);
         List<String> toExec = new ArrayList<>();
         for(SDVariable sdv : v){
             toExec.add(sdv.name());