Switch Java-based updater implementations to C++ ops (#384)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-17 14:41:49 +10:00 committed by GitHub
parent bb9cdb251e
commit 3967e039a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 40 additions and 105 deletions

View File

@ -30,11 +30,14 @@ public class AmsGradUpdater extends DynamicCustomOp {
// //
} }
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) { public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH,
double lr, double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration); this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration);
} }
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH,
@NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM,
@NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
addInputArgument(gradients, stateV, stateM, stateH); addInputArgument(gradients, stateV, stateM, stateH);
addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH); addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH);
addTArgument(lr, beta1, beta2, epsilon); addTArgument(lr, beta1, beta2, epsilon);

View File

@ -30,11 +30,14 @@ public class NadamUpdater extends DynamicCustomOp {
// //
} }
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr,
double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration); this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration);
} }
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates,
@NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2,
double epsilon, int iteration) {
addInputArgument(gradients, stateV, stateM); addInputArgument(gradients, stateV, stateM);
addOutputArgument(updates, updatedStateV, updatedStateM); addOutputArgument(updates, updatedStateV, updatedStateM);
addTArgument(lr, beta1, beta2, epsilon); addTArgument(lr, beta1, beta2, epsilon);

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,14 +20,12 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import lombok.val; import lombok.val;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AMSGrad; import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -103,27 +102,11 @@ public class AMSGradUpdater implements GradientUpdater<AMSGrad> {
double epsilon = config.getEpsilon(); double epsilon = config.getEpsilon();
//m_t = b_1 * m_{t-1} + (1-b_1) * g_t eq 1 pg 3 //m_t = b_1 * m_{t-1} + (1-b_1) * g_t eq 1 pg 3
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
m.muli(beta1).addi(oneMinusBeta1Grad);
//v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2 eq 1 pg 3 //v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2 eq 1 pg 3
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
v.muli(beta2).addi(oneMinusBeta2GradSquared);
double beta1t = FastMath.pow(beta1, iteration + 1);
double beta2t = FastMath.pow(beta2, iteration + 1);
//vHat_t = max(vHat_{t-1}, v_t) //vHat_t = max(vHat_{t-1}, v_t)
Transforms.max(vHat, v, false);
double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
if (Double.isNaN(alphat) || alphat == 0.0)
alphat = epsilon;
//gradient array contains: sqrt(vHat) + eps //gradient array contains: sqrt(vHat) + eps
Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon);
//gradient = alphat * m_t / (sqrt(vHat) + eps) //gradient = alphat * m_t / (sqrt(vHat) + eps)
gradient.rdivi(m).muli(alphat);
Nd4j.exec(new AmsGradUpdater(gradient, v, m, vHat, learningRate, beta1, beta2, epsilon, iteration));
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,9 +20,9 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaDelta; import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -104,16 +105,11 @@ public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
//Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf
//E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t
msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho));
//Calculate update: //Calculate update:
//dX = - g * RMS[delta x]_{t-1} / RMS[g]_t //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t
//Note: negative is applied in the DL4J step function: params -= update rather than params += update //Note: negative is applied in the DL4J step function: params -= update rather than params += update
INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false);
INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false);
INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t));
//Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2
msdx.muli(rho).addi(update.mul(update).muli(1 - rho));
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon));
} }
} }

View File

@ -18,16 +18,14 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.AdaGrad;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import static org.nd4j.linalg.ops.transforms.Transforms.sqrt;
/** /**
* Vectorized Learning Rate used per Connection Weight * Vectorized Learning Rate used per Connection Weight
@ -98,10 +96,6 @@ public class AdaGradUpdater implements GradientUpdater<AdaGrad> {
double learningRate = config.getLearningRate(iteration, epoch); double learningRate = config.getLearningRate(iteration, epoch);
double epsilon = config.getEpsilon(); double epsilon = config.getEpsilon();
historicalGradient.addi(gradient.mul(gradient)); Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(gradient, historicalGradient, learningRate, epsilon));
INDArray sqrtHistory = sqrt(historicalGradient.dup(gradientReshapeOrder), false).addi(epsilon);
// lr * gradient / (sqrt(sumSquaredGradients) + epsilon)
gradient.muli(sqrtHistory.rdivi(learningRate));
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -18,14 +19,11 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaMax; import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -99,22 +97,13 @@ public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
throw new IllegalStateException("Updater has not been initialized with view state"); throw new IllegalStateException("Updater has not been initialized with view state");
//m = B_1 * m + (1-B_1)*grad //m = B_1 * m + (1-B_1)*grad
m.muli(config.getBeta1()).addi(gradient.mul(1 - config.getBeta1()));
//u = max(B_2 * u, |grad|) //u = max(B_2 * u, |grad|)
u.muli(config.getBeta2());
Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later
Nd4j.getExecutioner().exec(new Max(u, gradient, u));
double beta1t = FastMath.pow(config.getBeta1(), iteration + 1); double lr = config.getLearningRate(iteration, epoch);
double b1 = config.getBeta1();
double b2 = config.getBeta2();
double eps = config.getEpsilon();
double learningRate = config.getLearningRate(iteration, epoch); Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(gradient, u, m, lr, b1, b2, eps, iteration));
double alphat = learningRate / (1.0 - beta1t);
if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) {
alphat = config.getEpsilon();
}
u.addi(1e-32); // prevent NaNs in params
gradient.assign(m).muli(alphat).divi(u);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -18,12 +19,11 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@ -102,20 +102,6 @@ public class AdamUpdater implements GradientUpdater<Adam> {
double learningRate = config.getLearningRate(iteration, epoch); double learningRate = config.getLearningRate(iteration, epoch);
double epsilon = config.getEpsilon(); double epsilon = config.getEpsilon();
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration));
m.muli(beta1).addi(oneMinusBeta1Grad);
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2);
v.muli(beta2).addi(oneMinusBeta2GradSquared);
double beta1t = FastMath.pow(beta1, iteration + 1);
double beta2t = FastMath.pow(beta2, iteration + 1);
double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
if (Double.isNaN(alphat) || alphat == 0.0)
alphat = epsilon;
INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon);
gradient.assign(m).muli(alphat).divi(sqrtV);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -21,6 +22,7 @@ import lombok.NonNull;
import org.apache.commons.math3.util.FastMath; import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Nadam; import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
@ -101,21 +103,6 @@ public class NadamUpdater implements GradientUpdater<Nadam> {
double learningRate = config.getLearningRate(iteration, epoch); double learningRate = config.getLearningRate(iteration, epoch);
double epsilon = config.getEpsilon(); double epsilon = config.getEpsilon();
INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration));
m.muli(beta1).addi(oneMinusBeta1Grad);
INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
v.muli(beta2).addi(oneMinusBeta2GradSquared);
double beta1t = FastMath.pow(beta1, iteration + 1);
INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t);
INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t);
INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate);
INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon);
gradient.assign(alphat).divi(sqrtV);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,7 +20,6 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.Nesterovs;
@ -95,16 +95,8 @@ public class NesterovsUpdater implements GradientUpdater<Nesterovs> {
//DL4J default is negative step function thus we flipped the signs: //DL4J default is negative step function thus we flipped the signs:
// x += mu * v_prev + (-1 - mu) * v // x += mu * v_prev + (-1 - mu) * v
//i.e., we do params -= updatedGradient, not params += updatedGradient //i.e., we do params -= updatedGradient, not params += updatedGradient
//v = mu * v - lr * gradient //v = mu * v - lr * gradient
INDArray vPrev = v.dup(gradientReshapeOrder);
v.muli(momentum).subi(gradient.dup(gradientReshapeOrder).muli(learningRate)); //Modify state array in-place
/* Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(gradient, v, learningRate, momentum));
Next line is equivalent to:
INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1));
gradient.assign(ret);
*/
Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient));
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -20,8 +21,8 @@ import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
@ -85,8 +86,7 @@ public class RmsPropUpdater implements GradientUpdater<RmsProp> {
double rmsDecay = config.getRmsDecay(); double rmsDecay = config.getRmsDecay();
double epsilon = config.getEpsilon(); double epsilon = config.getEpsilon();
lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay));
// lr * gradient / (sqrt(cache) + 1e-8) // lr * gradient / (sqrt(cache) + 1e-8)
gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup(gradientReshapeOrder), false).addi(epsilon)); Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(gradient, lastGradient, learningRate, rmsDecay, epsilon));
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,6 +20,7 @@ package org.nd4j.linalg.learning;
import lombok.Data; import lombok.Data;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import java.util.Collections; import java.util.Collections;
@ -56,6 +58,6 @@ public class SgdUpdater implements GradientUpdater<Sgd> {
@Override @Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) { public void applyUpdater(INDArray gradient, int iteration, int epoch) {
double lr = config.getLearningRate(iteration, epoch); double lr = config.getLearningRate(iteration, epoch);
gradient.muli(lr); Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(gradient, lr));
} }
} }