parent 4cf2afad2b
commit 2cd4522f94

@@ -0,0 +1,167 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4j.linalg.learning;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
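
/**
 * Plain-Java reference implementations of ND4J's gradient updaters (AdaDelta, Adam,
 * AdaMax, AMSGrad, Nadam, Nesterovs, RMSProp, SGD). Each apply* method overwrites
 * the gradient array with the update in-place and mutates the updater state arrays,
 * so results can be compared element-for-element against the built-in updaters
 * (see UpdaterValidation). This Javadoc is an editorial summary of the code below.
 */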
public class UpdaterJavaCode {

    private UpdaterJavaCode(){ }

    public static void applyAdaDeltaUpdater(INDArray gradient, INDArray msg, INDArray msdx, double rho, double epsilon){

        //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
        //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
        msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho));

        //Calculate update:
        //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
        //Note: negative is applied in the DL4J step function: params -= update rather than params += update
        INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false);
        INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false);
        INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t));

        //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)*(delta x_t)^2
        msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
    }

    public static void applyAdamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
                                        double epsilon, int iteration){
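        // Adam (Kingma & Ba 2014, Algorithm 1; https://arxiv.org/abs/1412.6980):
        //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        //   update = lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m_t / (sqrt(v_t) + eps)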
        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
        m.muli(beta1).addi(oneMinusBeta1Grad);

        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
        v.muli(beta2).addi(oneMinusBeta2GradSquared);

        double beta1t = FastMath.pow(beta1, iteration + 1);
        double beta2t = FastMath.pow(beta2, iteration + 1);

        double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
        if (Double.isNaN(alphat) || alphat == 0.0)
            alphat = epsilon;
        INDArray sqrtV = Transforms.sqrt(v.dup('c'), false).addi(epsilon);

        gradient.assign(m).muli(alphat).divi(sqrtV);
    }

    public static void applyAdaMaxUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
                                          double epsilon, int iteration){
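        // AdaMax (Kingma & Ba 2014, Algorithm 2): an Adam variant in which the
        // second-moment estimate is an infinity-norm running maximum rather than
        // a decaying average of squared gradients.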
        //m = B_1 * m + (1-B_1)*grad
        m.muli(beta1).addi(gradient.mul(1 - beta1));

        //u = max(B_2 * u, |grad|)
        v.muli(beta2);
        Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
        Nd4j.getExecutioner().exec(new Max(v, gradient, v));

        double beta1t = FastMath.pow(beta1, iteration + 1);

        double alphat = learningRate / (1.0 - beta1t);
        if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) {
            alphat = epsilon;
        }

        v.addi(1e-32); // prevent NaNs in params
        gradient.assign(m).muli(alphat).divi(v);
    }

    public static void applyAmsGradUpdater(INDArray gradient, INDArray m, INDArray v, INDArray vHat, double learningRate, double beta1, double beta2,
                                           double epsilon, int iteration){
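        // AMSGrad (Reddi et al. 2018, "On the Convergence of Adam and Beyond"):
        // Adam with a non-decreasing second-moment estimate, vHat_t = max(vHat_{t-1}, v_t).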
        //m_t = b_1 * m_{t-1} + (1-b_1) * g_t       eq 1 pg 3
        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
        m.muli(beta1).addi(oneMinusBeta1Grad);

        //v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2   eq 1 pg 3
        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
        v.muli(beta2).addi(oneMinusBeta2GradSquared);

        double beta1t = FastMath.pow(beta1, iteration + 1);
        double beta2t = FastMath.pow(beta2, iteration + 1);

        //vHat_t = max(vHat_{t-1}, v_t)
        Transforms.max(vHat, v, false);

        double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
        if (Double.isNaN(alphat) || alphat == 0.0)
            alphat = epsilon;

        //gradient array contains: sqrt(vHat) + eps
        Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon);

        //gradient = alphat * m_t / (sqrt(vHat) + eps)
        gradient.rdivi(m).muli(alphat);
    }

    public static void applyNadamUpdater(INDArray gradient, INDArray m, INDArray v, double learningRate, double beta1, double beta2,
                                         double epsilon, int iteration){
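        // Nadam (Dozat 2016, "Incorporating Nesterov Momentum into Adam"). As
        // implemented below, the update combines the bias-corrected momentum with
        // the current gradient, Nesterov-style:
        //   update = lr * (beta1 * m_t + (1 - beta1) * g_t) / (1 - beta1^t) / (sqrt(v_t) + eps)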
        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
        m.muli(beta1).addi(oneMinusBeta1Grad);

        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
        v.muli(beta2).addi(oneMinusBeta2GradSquared);

        double beta1t = FastMath.pow(beta1, iteration + 1);

        INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t);
        INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t);

        INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate);

        INDArray sqrtV = Transforms.sqrt(v.dup('c'), false).addi(epsilon);

        gradient.assign(alphat).divi(sqrtV);
    }

    public static void applyNesterovsUpdater(INDArray gradient, INDArray v, double lr, double momentum){
        //reference: https://cs231n.github.io/neural-networks-3/#sgd, 2nd equation
        //DL4J default is a negative step function, thus we flip the signs:
        // update = mu * v_prev + (-1 - mu) * v
        //i.e., we do params -= update, not params += update

        //v = mu * v - lr * gradient
        INDArray vPrev = v.dup('c');
        v.muli(momentum).subi(gradient.dup('c').muli(lr)); //Modify state array in-place

        /*
        Next line is equivalent to:
        INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
        gradient.assign(ret);
        */
        Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
    }

    public static void applyRmsProp(INDArray gradient, INDArray lastGradient, double learningRate, double rmsDecay, double epsilon){
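        // RMSProp: cache = decay * cache + (1 - decay) * g^2, then
        // update = lr * g / (sqrt(cache) + eps). Here lastGradient holds the cache.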
        lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay));
        // lr * gradient / (sqrt(cache) + eps)
        gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup('c'), false).addi(epsilon));
    }

    public static void applySgd(INDArray gradient, double lr){
        gradient.muli(lr);
    }
}
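
// ---------------------------------------------------------------------------
// Editorial note: a minimal usage sketch, not part of this commit. Each apply*
// method above overwrites `gradient` with the update in-place, and the caller
// then applies params -= update (the DL4J convention noted in the comments).
// The names below (params, computeGradient) are illustrative only:
//
//     INDArray params = Nd4j.rand(DataType.DOUBLE, 1, 5);
//     INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
//     INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
//     for (int iter = 0; iter < 3; iter++) {
//         INDArray grad = computeGradient(params);   // hypothetical helper
//         UpdaterJavaCode.applyAdamUpdater(grad, m, v, 1e-3, 0.9, 0.999, 1e-8, iter);
//         params.subi(grad);                         // params -= update
//     }
// ---------------------------------------------------------------------------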

@@ -0,0 +1,327 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.nd4j.linalg.learning;

import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.*;

import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
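
/**
 * Validates the plain-Java reference updaters in UpdaterJavaCode against ND4J's
 * built-in updater implementations: both are applied to identical gradient and
 * state arrays for several iterations, and the resulting state and in-place
 * updates must match exactly.
 */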
public class UpdaterValidation extends BaseNd4jTest {

    public UpdaterValidation(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    @Test
    public void testAdaDeltaUpdater(){
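        // Pattern repeated in every test below: seed identical state for the Java
        // reference and the ND4J updater, apply both to identical gradients for a
        // few iterations, and assert that state and update match exactly.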
        double rho = 0.95;
        double epsilon = 1e-6;

        INDArray msg = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray msdx = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("msg", msg.dup());
        state.put("msdx", msdx.dup());
        AdaDeltaUpdater u = (AdaDeltaUpdater) new AdaDelta(rho, epsilon).instantiate(state, true);

        assertEquals(msg, state.get("msg"));
        assertEquals(msdx, state.get("msdx"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon);

            u.applyUpdater(g2, i, 0);

            assertEquals(msg, state.get("msg"));
            assertEquals(msdx, state.get("msdx"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testAdamUpdater(){

        double lr = 1e-3;
        double beta1 = 0.9;
        double beta2 = 0.999;
        double eps = 1e-8;

        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("M", m.dup());
        state.put("V", v.dup());
        AdamUpdater u = (AdamUpdater) new Adam(lr, beta1, beta2, eps).instantiate(state, true);

        assertEquals(m, state.get("M"));
        assertEquals(v, state.get("V"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i);

            u.applyUpdater(g2, i, 0);

            assertEquals(m, state.get("M"));
            assertEquals(v, state.get("V"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testAdaMaxUpdater(){
        double lr = 1e-3;
        double beta1 = 0.9;
        double beta2 = 0.999;
        double eps = 1e-8;

        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("M", m.dup());
        state.put("V", v.dup());
        AdaMaxUpdater u = (AdaMaxUpdater) new AdaMax(lr, beta1, beta2, eps).instantiate(state, true);

        assertEquals(m, state.get("M"));
        assertEquals(v, state.get("V"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i);

            u.applyUpdater(g2, i, 0);

            assertEquals(m, state.get("M"));
            assertEquals(v, state.get("V"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testAmsGradUpdater(){
        double lr = 1e-3;
        double beta1 = 0.9;
        double beta2 = 0.999;
        double eps = 1e-8;

        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray vH = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("M", m.dup());
        state.put("V", v.dup());
        state.put("V_HAT", vH.dup());
        AMSGradUpdater u = (AMSGradUpdater) new AMSGrad(lr, beta1, beta2, eps).instantiate(state, true);

        assertEquals(m, state.get("M"));
        assertEquals(v, state.get("V"));
        assertEquals(vH, state.get("V_HAT"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i);

            u.applyUpdater(g2, i, 0);

            assertEquals(m, state.get("M"));
            assertEquals(v, state.get("V"));
            assertEquals(vH, state.get("V_HAT"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testNadamUpdater(){

        double lr = 1e-3;
        double beta1 = 0.9;
        double beta2 = 0.999;
        double eps = 1e-8;

        INDArray m = Nd4j.zeros(DataType.DOUBLE, 1, 5);
        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("M", m.dup());
        state.put("V", v.dup());
        NadamUpdater u = (NadamUpdater) new Nadam(lr, beta1, beta2, eps).instantiate(state, true);

        assertEquals(m, state.get("M"));
        assertEquals(v, state.get("V"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i);

            u.applyUpdater(g2, i, 0);

            assertEquals(m, state.get("M"));
            assertEquals(v, state.get("V"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testNesterovUpdater(){

        double lr = 0.1;
        double momentum = 0.9;

        INDArray v = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("V", v.dup());
        NesterovsUpdater u = (NesterovsUpdater) new Nesterovs(lr, momentum).instantiate(state, true);

        assertEquals(v, state.get("V"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum);

            u.applyUpdater(g2, i, 0);

            assertEquals(v, state.get("V"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testRmsPropUpdater(){

        double lr = 0.1;
        double decay = 0.95;
        double eps = 1e-8;

        INDArray g = Nd4j.zeros(DataType.DOUBLE, 1, 5);

        Map<String,INDArray> state = new HashMap<>();
        state.put("G", g.dup());
        RmsPropUpdater u = (RmsPropUpdater) new RmsProp(lr, decay, eps).instantiate(state, true);

        assertEquals(g, state.get("G"));

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps);

            u.applyUpdater(g2, i, 0);

            assertEquals(g, state.get("G"));
            assertEquals(g1, g2);
        }
    }

    @Test
    public void testSgdUpdater(){
        double lr = 0.1;

        SgdUpdater u = (SgdUpdater) new Sgd(lr).instantiate((Map<String,INDArray>)null, true);

        for( int i=0; i<3; i++ ) {
            INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
            INDArray g2 = g1.dup();

            UpdaterJavaCode.applySgd(g1, lr);

            u.applyUpdater(g2, i, 0);
            assertEquals(g1, g2);
        }
    }

    /*
    @Test
    public void createUpdaterTestCases(){
        Nd4j.create(1);
        Nd4j.getRandom().setSeed(12345);

        int size = 5;

        for(boolean random : new boolean[]{false, true}) {
            System.out.println("/////////////////////////////// " + (random ? "RANDOM TEST CASES" : "LINSPACE TEST CASES") + " ///////////////////////////////" );

            for (IUpdater u : new IUpdater[]{new AdaDelta(), new Adam(), new AdaMax(), new AMSGrad(), new Nadam(), new Nesterovs(), new RmsProp(), new Sgd()}) {

                System.out.println(" ===== " + u + " =====");

                long ss = u.stateSize(size);
                INDArray state = ss > 0 ? Nd4j.create(DataType.DOUBLE, 1, ss) : null;
                GradientUpdater gu = u.instantiate(state, true);

                System.out.println("Initial state:");
                Map<String, INDArray> m = gu.getState();
                for (String s : m.keySet()) {
                    System.out.println("state: " + s + " - " + m.get(s).toStringFull());
                }

                for (int i = 0; i < 3; i++) {
                    System.out.println("Iteration: " + i);
                    INDArray in;
                    if(random){
                        in = Nd4j.rand(DataType.DOUBLE, 1, 5);
                    } else {
                        in = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1, 5);
                    }

                    System.out.println("grad: " + in.toStringFull());
                    gu.applyUpdater(in, i, 0);
                    System.out.println("update: " + in.toStringFull());

                    m = gu.getState();
                    for (String s : m.keySet()) {
                        System.out.println("state: " + s + " - " + m.get(s).toStringFull());
                    }
                }
            }
        }
    }
    */
}