diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java
new file mode 100644
index 000000000..c80a04c55
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterJavaCode.java
@@ -0,0 +1,167 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.learning;
+
+import org.apache.commons.math3.util.FastMath;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
+import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import java.util.Map;
+
+public class UpdaterJavaCode {
+
+    private UpdaterJavaCode(){ }
+
+    public static void applyAdaDeltaUpdater(INDArray gradient, INDArray msg, INDArray msdx, double rho, double epsilon){
+
+        //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
+        //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
+        msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho));
+
+        //Calculate update:
+        //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
+        //Note: negative is applied in the DL4J step function: params -= update rather than params += update
+        INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false);
+        INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false);
+        INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t));
+
+        //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
+        msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
+    }
+
+
+    public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
+                                        double epsilon, int iteration){
+
+        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
+        m.muli(beta1).addi(oneMinusBeta1Grad);
+
+        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
+        v.muli(beta2).addi(oneMinusBeta2GradSquared);
+
+        double beta1t = FastMath.pow(beta1, iteration + 1);
+        double beta2t = FastMath.pow(beta2, iteration + 1);
+
+        double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
+        if (Double.isNaN(alphat) || alphat == 0.0)
+            alphat = epsilon;
+        INDArray sqrtV = Transforms.sqrt(v.dup('c'), false).addi(epsilon);
+
+        gradient.assign(m).muli(alphat).divi(sqrtV);
+    }
+
+    public static void applyAdaMaxUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
+                                          double epsilon, int iteration){
+
+        //m = B_1 * m + (1-B_1)*grad
+        m.muli(beta1).addi(gradient.mul(1 - beta1));
+
+        //u = max(B_2 * u, |grad|)
+        v.muli(beta2);
+        Transforms.abs(gradient, false);    //In-place should be OK here, original gradient values aren't used again later
+        Nd4j.getExecutioner().exec(new Max(v, gradient, v));
+
+        double beta1t = FastMath.pow(beta1, iteration + 1);
+
+        double alphat = learningRate / (1.0 - beta1t);
+        if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) {
+            alphat = epsilon;
+        }
+
+        v.addi(1e-32); // prevent NaNs in params
+        gradient.assign(m).muli(alphat).divi(v);
+    }
+
+    public static void applyAmsGradUpdater(INDArray gradient, INDArray m, INDArray v, INDArray vHat, double learningRate, double beta1, double beta2,
+                                           double epsilon, int iteration){
+        //m_t = b_1 * m_{t-1} + (1-b_1) * g_t       eq 1 pg 3
+        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
+        m.muli(beta1).addi(oneMinusBeta1Grad);
+
+        //v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2   eq 1 pg 3
+        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
+        v.muli(beta2).addi(oneMinusBeta2GradSquared);
+
+        double beta1t = FastMath.pow(beta1, iteration + 1);
+        double beta2t = FastMath.pow(beta2, iteration + 1);
+
+        //vHat_t = max(vHat_{t-1}, v_t)
+        Transforms.max(vHat, v, false);
+
+        double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
+        if (Double.isNaN(alphat) || alphat == 0.0)
+            alphat = epsilon;
+
+        //gradient array contains: sqrt(vHat) + eps
+        Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon);
+
+        //gradient = alphat * m_t / (sqrt(vHat) + eps)
+        gradient.rdivi(m).muli(alphat);
+    }
+
+    public static void applyNadamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
+                                         double epsilon, int iteration){
+
+        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
+        m.muli(beta1).addi(oneMinusBeta1Grad);
+
+        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
+        v.muli(beta2).addi(oneMinusBeta2GradSquared);
+
+        double beta1t = FastMath.pow(beta1, iteration + 1);
+
+        INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t);
+        INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t);
+
+        INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate);
+
+        INDArray sqrtV = Transforms.sqrt(v.dup('c'), false).addi(epsilon);
+
+        gradient.assign(alphat).divi(sqrtV);
+    }
+
+    public static void applyNesterovsUpdater(INDArray gradient, INDArray v, double lr, double momentum){
+        //reference https://cs231n.github.io/neural-networks-3/#sgd 2nd equation
+        //DL4J default is negative step function thus we flipped the signs:
+        // x += mu * v_prev + (-1 - mu) * v
+        //i.e., we do params -= updatedGradient, not params += updatedGradient
+
+        //v = mu * v - lr * gradient
+        INDArray vPrev = v.dup('c');
+        v.muli(momentum).subi(gradient.dup('c').muli(lr));  //Modify state array in-place
+
+        /*
+        Next line is equivalent to:
+        INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
+        gradient.assign(ret);
+        */
+        Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
+    }
+
+    public static void applyRmsProp(INDArray gradient, INDArray lastGradient, double learningRate, double rmsDecay, double epsilon){
+        lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay));
+        // lr * gradient / (sqrt(cache) + 1e-8)
+        gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup('c'), false).addi(epsilon));
+    }
+
+    public static void applySgd(INDArray gradient, double lr){
+        gradient.muli(lr);
+    }
+}
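Note for reviewers cross-checking the Adam-family methods above against the paper (this note is not part of the patch): the reference code folds the bias corrections into the step size instead of forming the bias-corrected moments explicitly. In standard Adam notation (Kingma & Ba), the per-element update is

    m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    update_t = lr * (m_t / (1 - beta1^t)) / (sqrt(v_t / (1 - beta2^t)) + eps)
             = [lr * sqrt(1 - beta2^t) / (1 - beta1^t)] * m_t / (sqrt(v_t) + eps * sqrt(1 - beta2^t))

The bracketed factor is the alphat variable computed in applyAdamUpdater and applyAmsGradUpdater; the only deviation from the textbook form is where eps enters (the code adds a plain eps to sqrt(v_t)), i.e. the "efficient" formulation. No negation appears here because, as the comments note, DL4J's step function later applies params -= update.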
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java
new file mode 100644
index 000000000..e8df8d7ae
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java
@@ -0,0 +1,327 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.learning;
+
+import org.junit.Test;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+import org.nd4j.linalg.learning.config.*;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class UpdaterValidation extends BaseNd4jTest {
+
+    public UpdaterValidation(Nd4jBackend backend) {
+        super(backend);
+    }
+
+    @Override
+    public char ordering() {
+        return 'c';
+    }
+
+    @Test
+    public void testAdaDeltaUpdater(){
+        double rho = 0.95;
+        double epsilon = 1e-6;
+
+        INDArray msg = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+        INDArray msdx = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("msg", msg.dup());
+        state.put("msdx", msdx.dup());
+        AdaDeltaUpdater u = (AdaDeltaUpdater) new AdaDelta(rho, epsilon).instantiate(state, true);
+
+        assertEquals(msg, state.get("msg"));
+        assertEquals(msdx, state.get("msdx"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(msg, state.get("msg"));
+            assertEquals(msdx, state.get("msdx"));
+            assertEquals(g1, g2);
+        }
+    }
+
+
+    @Test
+    public void testAdamUpdater(){
+
+        double lr = 1e-3;
+        double beta1 = 0.9;
+        double beta2 = 0.999;
+        double eps = 1e-8;
+
+        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("M", m.dup());
+        state.put("V", v.dup());
+        AdamUpdater u = (AdamUpdater) new Adam(lr, beta1, beta2, eps).instantiate(state, true);
+
+        assertEquals(m, state.get("M"));
+        assertEquals(v, state.get("V"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(m, state.get("M"));
+            assertEquals(v, state.get("V"));
+            assertEquals(g1, g2);
+        }
+    }
+
+    @Test
+    public void testAdaMaxUpdater(){
+        double lr = 1e-3;
+        double beta1 = 0.9;
+        double beta2 = 0.999;
+        double eps = 1e-8;
+
+        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("M", m.dup());
+        state.put("V", v.dup());
+        AdaMaxUpdater u = (AdaMaxUpdater) new AdaMax(lr, beta1, beta2, eps).instantiate(state, true);
+
+        assertEquals(m, state.get("M"));
+        assertEquals(v, state.get("V"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(m, state.get("M"));
+            assertEquals(v, state.get("V"));
+            assertEquals(g1, g2);
+        }
+    }
+
+    @Test
+    public void testAmsGradUpdater(){
+        double lr = 1e-3;
+        double beta1 = 0.9;
+        double beta2 = 0.999;
+        double eps = 1e-8;
+
+        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+        INDArray vH = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("M", m.dup());
+        state.put("V", v.dup());
+        state.put("V_HAT", vH.dup());
+        AMSGradUpdater u = (AMSGradUpdater) new AMSGrad(lr, beta1, beta2, eps).instantiate(state, true);
+
+        assertEquals(m, state.get("M"));
+        assertEquals(v, state.get("V"));
+        assertEquals(vH, state.get("V_HAT"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(m, state.get("M"));
+            assertEquals(v, state.get("V"));
+            assertEquals(vH, state.get("V_HAT"));
+            assertEquals(g1, g2);
+        }
+    }
+
+    @Test
+    public void testNadamUpdater(){
+
+        double lr = 1e-3;
+        double beta1 = 0.9;
+        double beta2 = 0.999;
+        double eps = 1e-8;
+
+        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("M", m.dup());
+        state.put("V", v.dup());
+        NadamUpdater u = (NadamUpdater) new Nadam(lr, beta1, beta2, eps).instantiate(state, true);
+
+        assertEquals(m, state.get("M"));
+        assertEquals(v, state.get("V"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(m, state.get("M"));
+            assertEquals(v, state.get("V"));
+            assertEquals(g1, g2);
+        }
+    }
+
+    @Test
+    public void testNesterovUpdater(){
+
+        double lr = 0.1;
+        double momentum = 0.9;
+
+        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("V", v.dup());
+        NesterovsUpdater u = (NesterovsUpdater) new Nesterovs(lr, momentum).instantiate(state, true);
+
+        assertEquals(v, state.get("V"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(v, state.get("V"));
+            assertEquals(g1, g2);
+        }
+    }
+
+    @Test
+    public void testRmsPropUpdater(){
+
+        double lr = 0.1;
+        double decay = 0.95;
+        double eps = 1e-8;
+
+        INDArray g = Nd4j.zeros(DataType.DOUBLE, 1, 5);
+
+        Map<String, INDArray> state = new HashMap<>();
+        state.put("G", g.dup());
+        RmsPropUpdater u = (RmsPropUpdater) new RmsProp(lr, decay, eps).instantiate(state, true);
+
+        assertEquals(g, state.get("G"));
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps);
+
+            u.applyUpdater(g2, i, 0);
+
+            assertEquals(g, state.get("G"));
+            assertEquals(g1, g2);
+        }
+    }
+
+    @Test
+    public void testSgdUpdater(){
+        double lr = 0.1;
+
+        SgdUpdater u = (SgdUpdater) new Sgd(lr).instantiate((Map<String, INDArray>) null, true);
+
+        for( int i=0; i<3; i++ ) {
+            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
+            INDArray g2 = g1.dup();
+
+            UpdaterJavaCode.applySgd(g1, lr);
+
+            u.applyUpdater(g2, i, 0);
+            assertEquals(g1, g2);
+        }
+    }
+
+
+    /*
+    @Test
+    public void createUpdaterTestCases(){
+        Nd4j.create(1);
+        Nd4j.getRandom().setSeed(12345);
+
+        int size = 5;
+
+        for(boolean random : new boolean[]{false, true}) {
+            System.out.println("/////////////////////////////// " + (random ? "RANDOM TEST CASES" : "LINSPACE TEST CASES") + " ///////////////////////////////" );
+
+            for (IUpdater u : new IUpdater[]{new AdaDelta(), new Adam(), new AdaMax(), new AMSGrad(), new Nadam(), new Nesterovs(), new RmsProp(), new Sgd()}) {
+
+                System.out.println(" ===== " + u + " =====");
+
+                long ss = u.stateSize(size);
+                INDArray state = ss > 0 ? Nd4j.create(DataType.DOUBLE, 1, ss) : null;
+                GradientUpdater gu = u.instantiate(state, true);
+
+                System.out.println("Initial state:");
+                Map<String, INDArray> m = gu.getState();
+                for (String s : m.keySet()) {
+                    System.out.println("state: " + s + " - " + m.get(s).toStringFull());
+                }
+
+                for (int i = 0; i < 3; i++) {
+                    System.out.println("Iteration: " + i);
+                    INDArray in;
+                    if(random){
+                        in = Nd4j.rand(DataType.DOUBLE, 1, 5);
+                    } else {
+                        in = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1, 5);
+                    }
+
+                    System.out.println("grad: " + in.toStringFull());
+                    gu.applyUpdater(in, 0, 0);
+                    System.out.println("update: " + in.toStringFull());
+
+                    m = gu.getState();
+                    for (String s : m.keySet()) {
+                        System.out.println("state: " + s + " - " + m.get(s).toStringFull());
+                    }
+                }
+            }
+        }
+    }
+    */
+}
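For completeness, the validation pattern used by these tests can also be exercised standalone. The sketch below is illustrative only and is not part of the patch; the class name AdamReferenceCheck, the main wrapper, and the printed difference are assumptions, while the ND4J calls mirror testAdamUpdater above. It performs a single Adam step with both the plain-Java reference and ND4J's built-in AdamUpdater and prints the resulting updates.

import java.util.HashMap;
import java.util.Map;

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdamUpdater;
import org.nd4j.linalg.learning.UpdaterJavaCode;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.ops.transforms.Transforms;

public class AdamReferenceCheck {

    public static void main(String[] args) {
        double lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;

        // State arrays for the plain-Java reference implementation
        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        // The built-in updater gets its own copy of the same (all-zero) state
        Map<String, INDArray> state = new HashMap<>();
        state.put("M", m.dup());
        state.put("V", v.dup());
        AdamUpdater updater = (AdamUpdater) new Adam(lr, beta1, beta2, eps).instantiate(state, true);

        // Feed the same gradient to both implementations
        INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1, 5);
        INDArray g2 = g1.dup();

        UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, 0);   // plain-Java reference
        updater.applyUpdater(g2, 0, 0);                                         // ND4J built-in updater

        // Both arrays now hold the update (not the new parameters); DL4J later applies params -= update
        System.out.println("reference: " + g1);
        System.out.println("built-in : " + g2);
        System.out.println("max abs difference: " + Transforms.abs(g1.sub(g2)).maxNumber());
    }
}

The tests above compare the two results with assertEquals rather than a numerical tolerance, since the reference methods are intended to mirror the built-in updaters operation for operation.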