Add updater tests/validation (#319)

Signed-off-by: Alex Black <blacka101@gmail.com>
Alex Black 2020-03-16 18:35:15 +11:00 committed by GitHub
parent 4cf2afad2b
commit 2cd4522f94
2 changed files with 494 additions and 0 deletions

org/nd4j/linalg/learning/UpdaterJavaCode.java

@@ -0,0 +1,167 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.learning;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Map;
public class UpdaterJavaCode {
private UpdaterJavaCode(){ }
public static void applyAdaDeltaUpdater(INDArray gradient, INDArray msg, INDArray msdx, double rho, double epsilon){
//Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
//E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho));
//Calculate update:
//dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
//Note: the negative sign is applied in the DL4J step function: params -= update rather than params += update
INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false);
INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false);
INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t));
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
}
public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
double epsilon, int iteration){
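//Reference: Kingma & Ba 2014, "Adam: A Method for Stochastic Optimization", https://arxiv.org/abs/1412.6980
//m_t = b_1 * m_{t-1} + (1-b_1) * g_t
//v_t = b_2 * v_{t-1} + (1-b_2) * g_t^2
//update = lr * sqrt(1 - b_2^t) / (1 - b_1^t) * m_t / (sqrt(v_t) + eps)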
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
m.muli(beta1).addi(oneMinusBeta1Grad);
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
v.muli(beta2).addi(oneMinusBeta2GradSquared);
double beta1t = FastMath.pow(beta1, iteration + 1);
double beta2t = FastMath.pow(beta2, iteration + 1);
double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
if (Double.isNaN(alphat) || alphat == 0.0)
alphat = epsilon;
INDArray sqrtV = Transforms.sqrt(v.dup('c'), false).addi(epsilon);
gradient.assign(m).muli(alphat).divi(sqrtV);
}
public static void applyAdaMaxUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
double epsilon, int iteration){
//m = B_1 * m + (1-B_1)*grad
m.muli(beta1).addi(gradient.mul(1 - beta1));
//u = max(B_2 * u, |grad|)
v.muli(beta2);
Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
Nd4j.getExecutioner().exec(new Max(v, gradient, v));
double beta1t = FastMath.pow(beta1, iteration + 1);
double alphat = learningRate / (1.0 - beta1t);
if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) {
alphat = epsilon;
}
v.addi(1e-32); // prevent NaNs in params
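//update = lr / (1 - b_1^t) * m_t / u_t (stored back into the gradient array)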
gradient.assign(m).muli(alphat).divi(v);
}
public static void applyAmsGradUpdater(INDArray gradient, INDArray m, INDArray v, INDArray vHat, double learningRate, double beta1, double beta2,
double epsilon, int iteration){
//m_t = b_1 * m_{t-1} + (1-b_1) * g_t eq 1 pg 3
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
m.muli(beta1).addi(oneMinusBeta1Grad);
//v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2 eq 1 pg 3
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
v.muli(beta2).addi(oneMinusBeta2GradSquared);
double beta1t = FastMath.pow(beta1, iteration + 1);
double beta2t = FastMath.pow(beta2, iteration + 1);
//vHat_t = max(vHat_{t-1}, v_t)
Transforms.max(vHat, v, false);
double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
if (Double.isNaN(alphat) || alphat == 0.0)
alphat = epsilon;
//gradient array contains: sqrt(vHat) + eps
Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon);
//gradient = alphat * m_t / (sqrt(vHat) + eps)
gradient.rdivi(m).muli(alphat);
}
public static void applyNadamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
double epsilon, int iteration){
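//Reference: Dozat 2016, "Incorporating Nesterov Momentum into Adam"
//m_t = b_1 * m_{t-1} + (1-b_1) * g_t
//v_t = b_2 * v_{t-1} + (1-b_2) * g_t^2
//update = lr * (b_1 * m_t + (1-b_1) * g_t) / ((1 - b_1^t) * (sqrt(v_t) + eps))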
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
m.muli(beta1).addi(oneMinusBeta1Grad);
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
v.muli(beta2).addi(oneMinusBeta2GradSquared);
double beta1t = FastMath.pow(beta1, iteration + 1);
INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t);
INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t);
INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate);
INDArray sqrtV = Transforms.sqrt(v.dup('c'), false).addi(epsilon);
gradient.assign(alphat).divi(sqrtV);
}
public static void applyNesterovsUpdater(INDArray gradient, INDArray v, double lr, double momentum){
//reference https://cs231n.github.io/neural-networks-3/#sgd 2nd equation
//DL4J's default is a negative step function, so the signs are flipped:
// x += mu * v_prev + (-1 - mu) * v
//i.e., we do params -= updatedGradient, not params += updatedGradient
//v = mu * v - lr * gradient
INDArray vPrev = v.dup('c');
v.muli(momentum).subi(gradient.dup('c').muli(lr)); //Modify state array in-place
/*
Next line is equivalent to:
INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
gradient.assign(ret);
*/
Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
}
public static void applyRmsProp(INDArray gradient, INDArray lastGradient, double learningRate, double rmsDecay, double epsilon){
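//cache = decay * cache + (1 - decay) * g^2 (the cache is stored in-place in lastGradient)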
lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay));
// update = lr * gradient / (sqrt(cache) + epsilon)
gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup('c'), false).addi(epsilon));
}
public static void applySgd(INDArray gradient, double lr){
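//update = lr * g; as above, applied as params -= update in the DL4J step function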
gradient.muli(lr);
}
}

org/nd4j/linalg/learning/UpdaterValidation.java

@@ -0,0 +1,327 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.learning;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.*;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
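/**
 * Validates the ND4J updater implementations against the plain-Java reference
 * implementations in {@link UpdaterJavaCode}: each test applies both versions to
 * identical gradients for several iterations and asserts that the resulting
 * updates and internal state arrays are equal.
 */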
public class UpdaterValidation extends BaseNd4jTest {
public UpdaterValidation(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void testAdaDeltaUpdater(){
double rho = 0.95;
double epsilon = 1e-6;
INDArray msg = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray msdx = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("msg", msg.dup());
state.put("msdx", msdx.dup());
AdaDeltaUpdater u = (AdaDeltaUpdater) new AdaDelta(rho,epsilon).instantiate(state, true);
assertEquals(msg, state.get("msg"));
assertEquals(msdx, state.get("msdx"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon);
u.applyUpdater(g2, i, 0);
assertEquals(msg, state.get("msg"));
assertEquals(msdx, state.get("msdx"));
assertEquals(g1, g2);
}
}
@Test
public void testAdamUpdater(){
double lr = 1e-3;
double beta1 = 0.9;
double beta2 = 0.999;
double eps = 1e-8;
INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("M", m.dup());
state.put("V", v.dup());
AdamUpdater u = (AdamUpdater) new Adam(lr, beta1, beta2, eps).instantiate(state, true);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
}
}
@Test
public void testAdaMaxUpdater(){
double lr = 1e-3;
double beta1 = 0.9;
double beta2 = 0.999;
double eps = 1e-8;
INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("M", m.dup());
state.put("V", v.dup());
AdaMaxUpdater u = (AdaMaxUpdater) new AdaMax(lr, beta1, beta2, eps).instantiate(state, true);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
}
}
@Test
public void testAmsGradUpdater(){
double lr = 1e-3;
double beta1 = 0.9;
double beta2 = 0.999;
double eps = 1e-8;
INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray vH = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("M", m.dup());
state.put("V", v.dup());
state.put("V_HAT", vH.dup());
AMSGradUpdater u = (AMSGradUpdater) new AMSGrad(lr, beta1, beta2, eps).instantiate(state, true);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(vH, state.get("V_HAT"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(vH, state.get("V_HAT"));
assertEquals(g1, g2);
}
}
@Test
public void testNadamUpdater(){
double lr = 1e-3;
double beta1 = 0.9;
double beta2 = 0.999;
double eps = 1e-8;
INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("M", m.dup());
state.put("V", v.dup());
NadamUpdater u = (NadamUpdater) new Nadam(lr, beta1, beta2, eps).instantiate(state, true);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
}
}
@Test
public void testNesterovUpdater(){
double lr = 0.1;
double momentum = 0.9;
INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("V", v.dup());
NesterovsUpdater u = (NesterovsUpdater) new Nesterovs(lr, momentum).instantiate(state, true);
assertEquals(v, state.get("V"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum);
u.applyUpdater(g2, i, 0);
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
}
}
@Test
public void testRmsPropUpdater(){
double lr = 0.1;
double decay = 0.95;
double eps = 1e-8;
INDArray g = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Map<String,INDArray> state = new HashMap<>();
state.put("G", g.dup());
RmsPropUpdater u = (RmsPropUpdater) new RmsProp(lr, decay, eps).instantiate(state, true);
assertEquals(g, state.get("G"));
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps);
u.applyUpdater(g2, i, 0);
assertEquals(g, state.get("G"));
assertEquals(g1, g2);
}
}
@Test
public void testSgdUpdater(){
double lr = 0.1;
SgdUpdater u = (SgdUpdater) new Sgd(lr).instantiate((Map<String,INDArray>)null, true);
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
UpdaterJavaCode.applySgd(g1, lr);
u.applyUpdater(g2, i, 0);
assertEquals(g1, g2);
}
}
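//Test-case generator (kept commented out): prints the gradient, update and updater state
//for each updater so that reference values can be captured for the tests above.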
/*
@Test
public void createUpdaterTestCases(){
Nd4j.create(1);
Nd4j.getRandom().setSeed(12345);
int size = 5;
for(boolean random : new boolean[]{false, true}) {
System.out.println("/////////////////////////////// " + (random ? "RANDOM TEST CASES" : "LINSPACE TEST CASES") + " ///////////////////////////////" );
for (IUpdater u : new IUpdater[]{new AdaDelta(), new Adam(), new AdaMax(), new AMSGrad(), new Nadam(), new Nesterovs(), new RmsProp(), new Sgd()}) {
System.out.println(" ===== " + u + " =====");
long ss = u.stateSize(size);
INDArray state = ss > 0 ? Nd4j.create(DataType.DOUBLE, 1, ss) : null;
GradientUpdater gu = u.instantiate(state, true);
System.out.println("Initial state:");
Map<String, INDArray> m = gu.getState();
for (String s : m.keySet()) {
System.out.println("state: " + s + " - " + m.get(s).toStringFull());
}
for (int i = 0; i < 3; i++) {
System.out.println("Iteration: " + i);
INDArray in;
if(random){
in = Nd4j.rand(DataType.DOUBLE, 1, 5);
} else {
in = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1, 5);
}
System.out.println("grad: " + in.toStringFull());
gu.applyUpdater(in, 0, 0);
System.out.println("update: " + in.toStringFull());
m = gu.getState();
for (String s : m.keySet()) {
System.out.println("state: " + s + " - " + m.get(s).toStringFull());
}
}
}
}
}
*/
}