From 3967e039a5b8dbaea1d5ae4221a068aee89856f8 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Fri, 17 Apr 2020 14:41:49 +1000
Subject: [PATCH] Switch Java-based updater implementations to C++ ops (#384)

Signed-off-by: Alex Black
---
 .../api/ops/impl/updaters/AmsGradUpdater.java |  7 ++++--
 .../api/ops/impl/updaters/NadamUpdater.java   |  7 ++++--
 .../nd4j/linalg/learning/AMSGradUpdater.java  | 25 +++----------------
 .../nd4j/linalg/learning/AdaDeltaUpdater.java | 12 +++------
 .../nd4j/linalg/learning/AdaGradUpdater.java  | 10 ++------
 .../nd4j/linalg/learning/AdaMaxUpdater.java   | 23 +++++------------
 .../org/nd4j/linalg/learning/AdamUpdater.java | 20 +++------------
 .../nd4j/linalg/learning/NadamUpdater.java    | 19 +++-----------
 .../linalg/learning/NesterovsUpdater.java     | 12 ++-------
 .../nd4j/linalg/learning/RmsPropUpdater.java  |  6 ++---
 .../org/nd4j/linalg/learning/SgdUpdater.java  |  4 ++-
 11 files changed, 40 insertions(+), 105 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java
index 35af113ad..5e8db1cfd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java
@@ -30,11 +30,14 @@ public class AmsGradUpdater extends DynamicCustomOp {
         //
     }

-    public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
+    public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH,
+                          double lr, double beta1, double beta2, double epsilon, int iteration) {
         this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration);
     }

-    public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
+    public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH,
+                          @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM,
+                          @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
         addInputArgument(gradients, stateV, stateM, stateH);
         addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH);
         addTArgument(lr, beta1, beta2, epsilon);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java
index ad4f374b7..325c85af5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java
@@ -30,11 +30,14 @@ public class NadamUpdater extends DynamicCustomOp {
         //
     }

-    public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
+    public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr,
+                        double beta1, double beta2, double epsilon, int iteration) {
         this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration);
     }

-    public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
+    public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates,
+                        @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2,
+                        double epsilon, int iteration) {
         addInputArgument(gradients, stateV, stateM);
         addOutputArgument(updates, updatedStateV, updatedStateM);
         addTArgument(lr, beta1, beta2, epsilon);
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java
index 79907a237..37d1cb01d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -19,14 +20,12 @@ package org.nd4j.linalg.learning;
 import lombok.Data;
 import lombok.NonNull;
 import lombok.val;
-import org.apache.commons.math3.util.FastMath;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
+import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.AMSGrad;
-import org.nd4j.linalg.ops.transforms.Transforms;

 import java.util.HashMap;
 import java.util.Map;
@@ -103,27 +102,11 @@ public class AMSGradUpdater implements GradientUpdater<AMSGrad> {
         double epsilon = config.getEpsilon();

         //m_t = b_1 * m_{t-1} + (1-b_1) * g_t       eq 1 pg 3
-        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
-        m.muli(beta1).addi(oneMinusBeta1Grad);
-
         //v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2   eq 1 pg 3
-        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
-        v.muli(beta2).addi(oneMinusBeta2GradSquared);
-
-        double beta1t = FastMath.pow(beta1, iteration + 1);
-        double beta2t = FastMath.pow(beta2, iteration + 1);
-
         //vHat_t = max(vHat_{t-1}, v_t)
-        Transforms.max(vHat, v, false);
-
-        double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
-        if (Double.isNaN(alphat) || alphat == 0.0)
-            alphat = epsilon;
-
         //gradient array contains: sqrt(vHat) + eps
-        Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon);
-
         //gradient = alphat * m_t / (sqrt(vHat) + eps)
-        gradient.rdivi(m).muli(alphat);
+
+        Nd4j.exec(new AmsGradUpdater(gradient, v, m, vHat, learningRate, beta1, beta2, epsilon, iteration));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java
index ced2a8c84..6aa7d7ab4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -19,9 +20,9 @@ package org.nd4j.linalg.learning;
 import lombok.Data;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.AdaDelta;
-import org.nd4j.linalg.ops.transforms.Transforms;

 import java.util.HashMap;
 import java.util.Map;
@@ -104,16 +105,11 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {

         //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
         //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
-        msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho));
-
         //Calculate update:
         //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
         //Note: negative is applied in the DL4J step function: params -= update rather than params += update
-        INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false);
-        INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false);
-        INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t));
-
         //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
-        msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
+
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java
index 09a530a51..ad355d1cc 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java
@@ -18,16 +18,14 @@ package org.nd4j.linalg.learning;

 import lombok.Data;
-import lombok.EqualsAndHashCode;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.AdaGrad;

 import java.util.Collections;
 import java.util.Map;

-import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
-

 /**
  * Vectorized Learning Rate used per Connection Weight
@@ -98,10 +96,6 @@ public class AdaGradUpdater implements GradientUpdater<AdaGrad> {
         double learningRate = config.getLearningRate(iteration, epoch);
         double epsilon = config.getEpsilon();

-        historicalGradient.addi(gradient.mul(gradient));
-
-        INDArray sqrtHistory = sqrt(historicalGradient.dup(gradientReshapeOrder), false).addi(epsilon);
-        // lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
-        gradient.muli(sqrtHistory.rdivi(learningRate));
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(gradient, historicalGradient, learningRate, epsilon));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java
index 20a908f1e..06fbde54d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -18,14 +19,11 @@ package org.nd4j.linalg.learning;

 import lombok.Data;
 import lombok.NonNull;
-import org.apache.commons.math3.util.FastMath;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.AdaMax;
-import org.nd4j.linalg.ops.transforms.Transforms;

 import java.util.HashMap;
 import java.util.Map;
@@ -99,22 +97,13 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
             throw new IllegalStateException("Updater has not been initialized with view state");

         //m = B_1 * m + (1-B_1)*grad
-        m.muli(config.getBeta1()).addi(gradient.mul(1 - config.getBeta1()));
-
         //u = max(B_2 * u, |grad|)
-        u.muli(config.getBeta2());
-        Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
-        Nd4j.getExecutioner().exec(new Max(u, gradient, u));
-
-        double beta1t = FastMath.pow(config.getBeta1(), iteration + 1);
+        double lr = config.getLearningRate(iteration, epoch);
+        double b1 = config.getBeta1();
+        double b2 = config.getBeta2();
+        double eps = config.getEpsilon();

-        double learningRate = config.getLearningRate(iteration, epoch);
-        double alphat = learningRate / (1.0 - beta1t);
-        if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) {
-            alphat = config.getEpsilon();
-        }
-
-        u.addi(1e-32); // prevent NaNs in params
-        gradient.assign(m).muli(alphat).divi(u);
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(gradient, u, m, lr, b1, b2, eps, iteration));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java
index e68af09f7..e72bfe5a4 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -18,12 +19,11 @@ package org.nd4j.linalg.learning;

 import lombok.Data;
 import lombok.NonNull;
-import org.apache.commons.math3.util.FastMath;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.Adam;
-import org.nd4j.linalg.ops.transforms.Transforms;

 import java.util.HashMap;
 import java.util.Map;
@@ -102,20 +102,6 @@ public class AdamUpdater implements GradientUpdater<Adam> {
         double learningRate = config.getLearningRate(iteration, epoch);
         double epsilon = config.getEpsilon();

-        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
-        m.muli(beta1).addi(oneMinusBeta1Grad);
-
-        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
-        v.muli(beta2).addi(oneMinusBeta2GradSquared);
-
-        double beta1t = FastMath.pow(beta1, iteration + 1);
-        double beta2t = FastMath.pow(beta2, iteration + 1);
-
-        double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
-        if (Double.isNaN(alphat) || alphat == 0.0)
-            alphat = epsilon;
-        INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon);
-
-        gradient.assign(m).muli(alphat).divi(sqrtV);
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java
index 18a29cc25..6432e288e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -21,6 +22,7 @@ import lombok.NonNull;
 import org.apache.commons.math3.util.FastMath;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.learning.config.Nadam;
 import org.nd4j.linalg.ops.transforms.Transforms;
@@ -101,21 +103,6 @@ public class NadamUpdater implements GradientUpdater<Nadam> {
         double learningRate = config.getLearningRate(iteration, epoch);
         double epsilon = config.getEpsilon();

-        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
-        m.muli(beta1).addi(oneMinusBeta1Grad);
-
-        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
-        v.muli(beta2).addi(oneMinusBeta2GradSquared);
-
-        double beta1t = FastMath.pow(beta1, iteration + 1);
-
-        INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t);
-        INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t);
-
-        INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate);
-
-        INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon);
-
-        gradient.assign(alphat).divi(sqrtV);
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java
index 64a9a6f87..2a18b78d8 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -19,7 +20,6 @@ package org.nd4j.linalg.learning;
 import lombok.Data;
 import lombok.NonNull;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Nesterovs;
@@ -95,16 +95,8 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
         //DL4J default is negative step function thus we flipped the signs:
         // x += mu * v_prev + (-1 - mu) * v
         //i.e., we do params -= updatedGradient, not params += updatedGradient

         //v = mu * v - lr * gradient
-        INDArray vPrev = v.dup(gradientReshapeOrder);
-        v.muli(momentum).subi(gradient.dup(gradientReshapeOrder).muli(learningRate)); //Modify state array in-place
-
-        /*
-        Next line is equivalent to:
-        INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
-        gradient.assign(ret);
-        */
-        Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(gradient, v, learningRate, momentum));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java
index e2d68c4bf..866f9ce0d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -20,8 +21,8 @@ import lombok.Data;
 import lombok.NonNull;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.RmsProp;
-import org.nd4j.linalg.ops.transforms.Transforms;

 import java.util.Collections;
 import java.util.Map;
@@ -85,8 +86,7 @@ public class RmsPropUpdater implements GradientUpdater<RmsProp> {
         double rmsDecay = config.getRmsDecay();
         double epsilon = config.getEpsilon();

-        lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay));
         // lr * gradient / (sqrt(cache) + 1e-8)
-        gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup(gradientReshapeOrder), false).addi(epsilon));
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(gradient, lastGradient, learningRate, rmsDecay, epsilon));
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java
index 1eca487c1..a2d0b0214 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -19,6 +20,7 @@ package org.nd4j.linalg.learning;

 import lombok.Data;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.Sgd;

 import java.util.Collections;
@@ -56,6 +58,6 @@ public class SgdUpdater implements GradientUpdater<Sgd> {
     @Override
     public void applyUpdater(INDArray gradient, int iteration, int epoch) {
         double lr = config.getLearningRate(iteration, epoch);
-        gradient.muli(lr);
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(gradient, lr));
     }
 }
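
Reviewer note (not part of the patch): a minimal, self-contained sketch of how the new C++-backed updater ops are driven, using the same AdamUpdater op constructor that AdamUpdater.applyUpdater() calls in the diff above. The array shapes and hyperparameter values here are illustrative assumptions, and the in-place aliasing behaviour is inferred from the single-array-set constructors shown in this patch (e.g. NadamUpdater's this(gradients, stateV, stateM, gradients, stateV, stateM, ...)).

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater;
    import org.nd4j.linalg.factory.Nd4j;

    public class AdamUpdaterSketch {
        public static void main(String[] args) {
            // Illustrative shapes and hyperparameters -- not taken from the patch
            INDArray grad = Nd4j.rand(1, 10);   // gradient; transformed in-place into the update
            INDArray v = Nd4j.zeros(1, 10);     // second-moment (uncentered variance) state
            INDArray m = Nd4j.zeros(1, 10);     // first-moment (mean) state
            double lr = 1e-3, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8;
            int iteration = 0;

            // One native op call replaces the old per-array Java math; with this
            // constructor the outputs alias the inputs, so grad, v and m are all
            // updated in-place, matching the applyUpdater() semantics above.
            Nd4j.exec(new AdamUpdater(grad, v, m, lr, beta1, beta2, epsilon, iteration));
        }
    }

Design-wise, keeping the org.nd4j.linalg.learning.*Updater classes as thin wrappers leaves the GradientUpdater interface and existing DL4J training code untouched, while moving the element-wise math into a single op avoids the temporary arrays (gradient.mul(...), v.dup(...), etc.) that the old Java implementations allocated on every step.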