Switch Java-based updater implementations to C++ ops (#384)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
bb9cdb251e
commit
3967e039a5
|
@ -30,11 +30,14 @@ public class AmsGradUpdater extends DynamicCustomOp {
|
|||
//
|
||||
}
|
||||
|
||||
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH,
|
||||
double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration);
|
||||
}
|
||||
|
||||
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH,
|
||||
@NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM,
|
||||
@NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
addInputArgument(gradients, stateV, stateM, stateH);
|
||||
addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH);
|
||||
addTArgument(lr, beta1, beta2, epsilon);
|
||||
|
|
|
@ -30,11 +30,14 @@ public class NadamUpdater extends DynamicCustomOp {
|
|||
//
|
||||
}
|
||||
|
||||
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr,
|
||||
double beta1, double beta2, double epsilon, int iteration) {
|
||||
this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration);
|
||||
}
|
||||
|
||||
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates,
|
||||
@NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2,
|
||||
double epsilon, int iteration) {
|
||||
addInputArgument(gradients, stateV, stateM);
|
||||
addOutputArgument(updates, updatedStateV, updatedStateM);
|
||||
addTArgument(lr, beta1, beta2, epsilon);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,14 +20,12 @@ package org.nd4j.linalg.learning;
|
|||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
|
||||
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.AMSGrad;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -103,27 +102,11 @@ public class AMSGradUpdater implements GradientUpdater<AMSGrad> {
|
|||
double epsilon = config.getEpsilon();
|
||||
|
||||
//m_t = b_1 * m_{t-1} + (1-b_1) * g_t eq 1 pg 3
|
||||
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
|
||||
m.muli(beta1).addi(oneMinusBeta1Grad);
|
||||
|
||||
//v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2 eq 1 pg 3
|
||||
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
|
||||
v.muli(beta2).addi(oneMinusBeta2GradSquared);
|
||||
|
||||
double beta1t = FastMath.pow(beta1, iteration + 1);
|
||||
double beta2t = FastMath.pow(beta2, iteration + 1);
|
||||
|
||||
//vHat_t = max(vHat_{t-1}, v_t)
|
||||
Transforms.max(vHat, v, false);
|
||||
|
||||
double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
|
||||
if (Double.isNaN(alphat) || alphat == 0.0)
|
||||
alphat = epsilon;
|
||||
|
||||
//gradient array contains: sqrt(vHat) + eps
|
||||
Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon);
|
||||
|
||||
//gradient = alphat * m_t / (sqrt(vHat) + eps)
|
||||
gradient.rdivi(m).muli(alphat);
|
||||
|
||||
Nd4j.exec(new AmsGradUpdater(gradient, v, m, vHat, learningRate, beta1, beta2, epsilon, iteration));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,9 +20,9 @@ package org.nd4j.linalg.learning;
|
|||
import lombok.Data;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.AdaDelta;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -104,16 +105,11 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
|
|||
|
||||
//Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
|
||||
//E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
|
||||
msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho));
|
||||
|
||||
//Calculate update:
|
||||
//dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
|
||||
//Note: negative is applied in the DL4J step function: params -= update rather than params += update
|
||||
INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false);
|
||||
INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false);
|
||||
INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t));
|
||||
|
||||
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
|
||||
msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
|
||||
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,16 +18,14 @@ package org.nd4j.linalg.learning;
|
|||
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.AdaGrad;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
|
||||
|
||||
|
||||
/**
|
||||
* Vectorized Learning Rate used per Connection Weight
|
||||
|
@ -98,10 +96,6 @@ public class AdaGradUpdater implements GradientUpdater<AdaGrad> {
|
|||
double learningRate = config.getLearningRate(iteration, epoch);
|
||||
double epsilon = config.getEpsilon();
|
||||
|
||||
historicalGradient.addi(gradient.mul(gradient));
|
||||
|
||||
INDArray sqrtHistory = sqrt(historicalGradient.dup(gradientReshapeOrder), false).addi(epsilon);
|
||||
// lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
|
||||
gradient.muli(sqrtHistory.rdivi(learningRate));
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(gradient, historicalGradient, learningRate, epsilon));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -18,14 +19,11 @@ package org.nd4j.linalg.learning;
|
|||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.AdaMax;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -99,22 +97,13 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
|
|||
throw new IllegalStateException("Updater has not been initialized with view state");
|
||||
|
||||
//m = B_1 * m + (1-B_1)*grad
|
||||
m.muli(config.getBeta1()).addi(gradient.mul(1 - config.getBeta1()));
|
||||
|
||||
//u = max(B_2 * u, |grad|)
|
||||
u.muli(config.getBeta2());
|
||||
Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
|
||||
Nd4j.getExecutioner().exec(new Max(u, gradient, u));
|
||||
|
||||
double beta1t = FastMath.pow(config.getBeta1(), iteration + 1);
|
||||
double lr = config.getLearningRate(iteration, epoch);
|
||||
double b1 = config.getBeta1();
|
||||
double b2 = config.getBeta2();
|
||||
double eps = config.getEpsilon();
|
||||
|
||||
double learningRate = config.getLearningRate(iteration, epoch);
|
||||
double alphat = learningRate / (1.0 - beta1t);
|
||||
if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) {
|
||||
alphat = config.getEpsilon();
|
||||
}
|
||||
|
||||
u.addi(1e-32); // prevent NaNs in params
|
||||
gradient.assign(m).muli(alphat).divi(u);
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(gradient, u, m, lr, b1, b2, eps, iteration));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -18,12 +19,11 @@ package org.nd4j.linalg.learning;
|
|||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
@ -102,20 +102,6 @@ public class AdamUpdater implements GradientUpdater<Adam> {
|
|||
double learningRate = config.getLearningRate(iteration, epoch);
|
||||
double epsilon = config.getEpsilon();
|
||||
|
||||
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
|
||||
m.muli(beta1).addi(oneMinusBeta1Grad);
|
||||
|
||||
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
|
||||
v.muli(beta2).addi(oneMinusBeta2GradSquared);
|
||||
|
||||
double beta1t = FastMath.pow(beta1, iteration + 1);
|
||||
double beta2t = FastMath.pow(beta2, iteration + 1);
|
||||
|
||||
double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
|
||||
if (Double.isNaN(alphat) || alphat == 0.0)
|
||||
alphat = epsilon;
|
||||
INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon);
|
||||
|
||||
gradient.assign(m).muli(alphat).divi(sqrtV);
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -21,6 +22,7 @@ import lombok.NonNull;
|
|||
import org.apache.commons.math3.util.FastMath;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.Nadam;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
@ -101,21 +103,6 @@ public class NadamUpdater implements GradientUpdater<Nadam> {
|
|||
double learningRate = config.getLearningRate(iteration, epoch);
|
||||
double epsilon = config.getEpsilon();
|
||||
|
||||
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
|
||||
m.muli(beta1).addi(oneMinusBeta1Grad);
|
||||
|
||||
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
|
||||
v.muli(beta2).addi(oneMinusBeta2GradSquared);
|
||||
|
||||
double beta1t = FastMath.pow(beta1, iteration + 1);
|
||||
|
||||
INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t);
|
||||
INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t);
|
||||
|
||||
INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate);
|
||||
|
||||
INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon);
|
||||
|
||||
gradient.assign(alphat).divi(sqrtV);
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,7 +20,6 @@ package org.nd4j.linalg.learning;
|
|||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||
|
@ -95,16 +95,8 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
|
|||
//DL4J default is negative step function thus we flipped the signs:
|
||||
// x += mu * v_prev + (-1 - mu) * v
|
||||
//i.e., we do params -= updatedGradient, not params += updatedGradient
|
||||
|
||||
//v = mu * v - lr * gradient
|
||||
INDArray vPrev = v.dup(gradientReshapeOrder);
|
||||
v.muli(momentum).subi(gradient.dup(gradientReshapeOrder).muli(learningRate)); //Modify state array in-place
|
||||
|
||||
/*
|
||||
Next line is equivalent to:
|
||||
INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
|
||||
gradient.assign(ret);
|
||||
*/
|
||||
Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(gradient, v, learningRate, momentum));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -20,8 +21,8 @@ import lombok.Data;
|
|||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.RmsProp;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
@ -85,8 +86,7 @@ public class RmsPropUpdater implements GradientUpdater<RmsProp> {
|
|||
double rmsDecay = config.getRmsDecay();
|
||||
double epsilon = config.getEpsilon();
|
||||
|
||||
lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay));
|
||||
// lr * gradient / (sqrt(cache) + 1e-8)
|
||||
gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup(gradientReshapeOrder), false).addi(epsilon));
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(gradient, lastGradient, learningRate, rmsDecay, epsilon));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,6 +20,7 @@ package org.nd4j.linalg.learning;
|
|||
import lombok.Data;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
|
||||
import java.util.Collections;
|
||||
|
@ -56,6 +58,6 @@ public class SgdUpdater implements GradientUpdater<Sgd> {
|
|||
@Override
|
||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||
double lr = config.getLearningRate(iteration, epoch);
|
||||
gradient.muli(lr);
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(gradient, lr));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue