AdaBelief updater: it was agreed to modify changes on the copy of AdamUpdater. This way we can improve it later.
https://arxiv.org/pdf/2010.07468.pdf Signed-off-by: AbdelRauf <rauf@konduit.ai>master
parent
c523c4f0c7
commit
a4efb4d4e9
|
@ -0,0 +1,96 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
// @author Abdelrauf(rauf@konduit.ai)
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(adabelief_updater, 3, 3, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initStateU = INPUT_VARIABLE(1);
|
||||||
|
const auto initStateM = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateU = OUTPUT_VARIABLE(1);
|
||||||
|
auto stateM = OUTPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
// todo maybe we need an error like on Java side
|
||||||
|
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADABELIEF UPDATER OP: input state V must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADABELIEF UPDATER OP: input state M must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "ADABELIEF UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dBeta1, dBeta2, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
const auto lr = INPUT_VARIABLE(3);
|
||||||
|
const auto beta1 = INPUT_VARIABLE(4);
|
||||||
|
const auto beta2 = INPUT_VARIABLE(5);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(6);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "ADABELIEF UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(beta1->isScalar(), 0, "ADABELIEF UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
|
||||||
|
REQUIRE_TRUE(beta2->isScalar(), 0, "ADABELIEF UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADABELIEF UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dBeta1 = beta1->e<double>(0);
|
||||||
|
dBeta2 = beta2->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dBeta1 = T_ARG(1);
|
||||||
|
dBeta2 = T_ARG(2);
|
||||||
|
dEpsilon = T_ARG(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterAdaBelief(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(adabelief_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -144,6 +144,29 @@ namespace sd {
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_adam_updater)
|
#if NOT_EXCLUDED(OP_adam_updater)
|
||||||
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// AdaBelief
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - gradient state V
|
||||||
|
* 2 - gradient state M
|
||||||
|
* Optional :
|
||||||
|
* 3 - scalar learning rate value
|
||||||
|
* 4 - beta 1 value
|
||||||
|
* 5 - beta 2 value
|
||||||
|
* 6 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - beta 1 value
|
||||||
|
* 2 - beta 2 value
|
||||||
|
* 3 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* I args
|
||||||
|
* 0 - iteration
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_adabelief_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(adabelief_updater, 3, 3, true, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
// AdaDelta
|
// AdaDelta
|
||||||
/* Input arrays :
|
/* Input arrays :
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
// @author Abdelrauf (rauf@konduit.ai)
|
||||||
|
|
||||||
|
// https://arxiv.org/pdf/2010.07468.pdf
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void adaBeliefUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update,
|
||||||
|
NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
|
||||||
|
const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* initU = initStateU.bufferAsT<T>();
|
||||||
|
const T* initM = initStateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* stU = stateU.bufferAsT<T>();
|
||||||
|
T* stM = stateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
const T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
|
||||||
|
|
||||||
|
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||||
|
update.ordering() == stateU.ordering() &&
|
||||||
|
stateU.ordering() == initStateU.ordering() &&
|
||||||
|
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
|
||||||
|
stU[i] = beta2 * initU[i] + (grad[i] - stM[i]) * (grad[i] - stM[i]) * (1 - beta2) + epsilon;
|
||||||
|
|
||||||
|
up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[i]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo());
|
||||||
|
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo());
|
||||||
|
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo());
|
||||||
|
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo());
|
||||||
|
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords);
|
||||||
|
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords);
|
||||||
|
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords);
|
||||||
|
const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords);
|
||||||
|
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords);
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||||
|
stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
|
||||||
|
|
||||||
|
up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,143 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
// @author Abdelrauf (rauf@konduit.ai)
|
||||||
|
|
||||||
|
// https://arxiv.org/pdf/2010.07468.pdf
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void adaBeliefUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm,
|
||||||
|
const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV,
|
||||||
|
const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
|
||||||
|
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto initU = reinterpret_cast<const T*>(vinv);
|
||||||
|
const auto initM = reinterpret_cast<const T*>(vinm);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto stU = reinterpret_cast<T*>(vstV);
|
||||||
|
auto stM = reinterpret_cast<T*>(vstM);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T epsilonT;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
|
||||||
|
|
||||||
|
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
|
||||||
|
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
|
||||||
|
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
|
||||||
|
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
|
||||||
|
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
|
||||||
|
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering){
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
|
||||||
|
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
|
||||||
|
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
|
||||||
|
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||||
|
stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
|
||||||
|
|
||||||
|
up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void adaBeliefUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
|
||||||
|
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
adaBeliefUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, * stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
|
||||||
|
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
|
||||||
|
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
|
||||||
|
const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "adamUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(),
|
||||||
|
initStateU.specialBuffer(), initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(),
|
||||||
|
update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), stateU.specialShapeInfo(),
|
||||||
|
stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -40,7 +40,7 @@ namespace helpers {
|
||||||
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
|
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
|
||||||
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
|
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1057,6 +1057,76 @@ TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) {
|
||||||
ASSERT_TRUE(stateM.isSameShape(results.at(2)));
|
ASSERT_TRUE(stateM.isSameShape(results.at(2)));
|
||||||
ASSERT_TRUE(stateM.equalsTo(results.at(2)));
|
ASSERT_TRUE(stateM.equalsTo(results.at(2)));
|
||||||
}
|
}
|
||||||
|
//
|
||||||
|
TEST_F(DeclarableOpsTests18, TestUpdaterAdaBelief1) {
|
||||||
|
//here is the python code used for generating test numbers
|
||||||
|
//import numpy as np
|
||||||
|
//alpha=0.001
|
||||||
|
//beta1=0.9
|
||||||
|
//beta2=0.999
|
||||||
|
//epsilon=1.e-8
|
||||||
|
//#https://arxiv.org/pdf/2010.07468.pdf
|
||||||
|
//def update( t, w, gradW, mt, st):
|
||||||
|
// mt = beta1* mt + (1- beta1)*gradW
|
||||||
|
// st = beta2* st + (1- beta2)*((gradW-mt)**2) + epsilon
|
||||||
|
// mt_corr = mt/(1- beta1**t)
|
||||||
|
// st_corr = st/(1- beta2**t)
|
||||||
|
// upW= alpha*(mt_corr/(np.sqrt(st_corr)+epsilon))
|
||||||
|
// w = w - upW
|
||||||
|
// return ( w, upW, mt, st )
|
||||||
|
//#if you want to test with more precision np.set_printoptions(precision=9)
|
||||||
|
//grad = np.array([1,2,3,4,5], dtype = np.float32)
|
||||||
|
//w=np.zeros(5, dtype = np.float32)
|
||||||
|
//mt=np.zeros(5, dtype = np.float32)
|
||||||
|
//st = np.zeros(5, dtype = np.float32)
|
||||||
|
//for t in range(1,4):
|
||||||
|
// w, upW, mt, st = update(t,w,grad, mt,st )
|
||||||
|
// print(f"---{t}----")
|
||||||
|
// print(f"update {upW}")
|
||||||
|
// print(f" s state {st} ")
|
||||||
|
// print(f" m state {mt} ")
|
||||||
|
|
||||||
|
|
||||||
|
NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32);
|
||||||
|
NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
|
||||||
|
NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray update('c', { 1, 5 }, DataType::FLOAT32);
|
||||||
|
|
||||||
|
sd::ops::adabelief_updater op;
|
||||||
|
auto t=0;
|
||||||
|
Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { });
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
NDArray updateExp0('c', { 1, 5 }, { 0.0011111f, 0.00111111f, 0.00111111f, 0.00111111f, 0.00111111f }, DataType::FLOAT32);
|
||||||
|
NDArray stateV('c', { 1, 5 }, { 0.00081001f, 0.00324001f, 0.00729001f, 0.01296001f, 0.02025001f }, DataType::FLOAT32);
|
||||||
|
NDArray stateM0('c', { 1, 5 }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, DataType::FLOAT32);
|
||||||
|
ASSERT_TRUE(update.equalsTo(updateExp0));
|
||||||
|
ASSERT_TRUE(initU.equalsTo(stateV));
|
||||||
|
ASSERT_TRUE(initM.equalsTo(stateM0));
|
||||||
|
t=1;
|
||||||
|
status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { t});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
NDArray updateExp1('c', { 1, 5 }, { 0.001168f, 0.001168f, 0.001168f, 0.001168f, 0.001168f}, DataType::FLOAT32);
|
||||||
|
NDArray stateV1('c', { 1, 5 }, { 0.00146531f, 0.00586118f, 0.01318763f, 0.02344466f, 0.03663227f }, DataType::FLOAT32);
|
||||||
|
NDArray stateM1('c', { 1, 5 }, { 0.19f, 0.38f, 0.57000005f, 0.76f, 0.95f }, DataType::FLOAT32);
|
||||||
|
ASSERT_TRUE(update.equalsTo(updateExp1));
|
||||||
|
ASSERT_TRUE(initU.equalsTo(stateV1));
|
||||||
|
ASSERT_TRUE(initM.equalsTo(stateM1));
|
||||||
|
t=2;
|
||||||
|
status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, {t});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
|
NDArray updateExp2('c', { 1, 5 }, { 0.00122557f, 0.00122558f, 0.00122558f, 0.00122558f, 0.00122558f }, DataType::FLOAT32);
|
||||||
|
NDArray stateV2('c', { 1, 5 }, { 0.0019953f, 0.00798109f, 0.01795742f, 0.03192428f, 0.04988168f }, DataType::FLOAT32);
|
||||||
|
NDArray stateM2('c', { 1, 5 }, { 0.271f, 0.542f, 0.813f, 1.084f, 1.355f }, DataType::FLOAT32);
|
||||||
|
|
||||||
|
ASSERT_TRUE(update.equalsTo(updateExp2));
|
||||||
|
ASSERT_TRUE(initU.equalsTo(stateV2));
|
||||||
|
ASSERT_TRUE(initM.equalsTo(stateM2));
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) {
|
TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) {
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
//https://arxiv.org/pdf/2010.07468.pdf
|
||||||
|
|
||||||
|
public class AdaBeliefUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AdaBeliefUpdater() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
addInputArgument(gradients, stateU, stateM);
|
||||||
|
addOutputArgument(updates, updatedStateU, updatedStateM);
|
||||||
|
addTArgument(lr, beta1, beta2, epsilon);
|
||||||
|
addIArgument(iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "adabelief_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,107 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.learning;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaBelief;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
//https://arxiv.org/pdf/2010.07468.pdf
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
|
||||||
|
public static final String M_STATE = "M";
|
||||||
|
public static final String S_STATE = "S";
|
||||||
|
|
||||||
|
private AdaBelief config;
|
||||||
|
private INDArray m, s; // moving avg & sqrd gradients
|
||||||
|
|
||||||
|
private char gradientReshapeOrder;
|
||||||
|
|
||||||
|
public AdaBeliefUpdater(AdaBelief config) {
|
||||||
|
this.config = config;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setState(@NonNull Map<String, INDArray> stateMap, boolean initialize) {
|
||||||
|
if(!stateMap.containsKey(M_STATE) || !stateMap.containsKey(S_STATE) || stateMap.size() != 2){
|
||||||
|
throw new IllegalStateException("State map should contain only keys [" + M_STATE + "," + S_STATE + "] but has keys " + stateMap.keySet());
|
||||||
|
}
|
||||||
|
this.m = stateMap.get(M_STATE);
|
||||||
|
this.s = stateMap.get(S_STATE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, INDArray> getState() {
|
||||||
|
Map<String,INDArray> r = new HashMap<>();
|
||||||
|
r.put(M_STATE, m);
|
||||||
|
r.put(S_STATE, s);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
|
||||||
|
if (!viewArray.isRowVector())
|
||||||
|
throw new IllegalArgumentException("Invalid input: expect row vector input");
|
||||||
|
if (initialize)
|
||||||
|
viewArray.assign(0);
|
||||||
|
long length = viewArray.length();
|
||||||
|
this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
|
||||||
|
this.s = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
|
||||||
|
|
||||||
|
//Reshape to match the expected shape of the input gradient arrays
|
||||||
|
this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f');
|
||||||
|
this.s = Shape.newShapeNoCopy(this.s, gradientShape, gradientOrder == 'f');
|
||||||
|
if (m == null || s == null)
|
||||||
|
throw new IllegalStateException("Could not correctly reshape gradient view arrays");
|
||||||
|
|
||||||
|
this.gradientReshapeOrder = gradientOrder;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the update based on the given gradient
|
||||||
|
*
|
||||||
|
* @param gradient the gradient to get the update for
|
||||||
|
* @param iteration
|
||||||
|
* @return the gradient
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||||
|
if (m == null || s == null)
|
||||||
|
throw new IllegalStateException("Updater has not been initialized with view state");
|
||||||
|
|
||||||
|
double beta1 = config.getBeta1();
|
||||||
|
double beta2 = config.getBeta2();
|
||||||
|
double learningRate = config.getLearningRate(iteration, epoch);
|
||||||
|
double epsilon = config.getEpsilon();
|
||||||
|
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(gradient, s, m, learningRate, beta1, beta2, epsilon, iteration));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,132 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.learning.config;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.learning.AdaBeliefUpdater;
|
||||||
|
import org.nd4j.linalg.learning.GradientUpdater;
|
||||||
|
import org.nd4j.linalg.schedule.ISchedule;
|
||||||
|
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AdaBelief
|
||||||
|
* https://arxiv.org/pdf/2010.07468.pdf
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
@Builder(builderClassName = "Builder")
|
||||||
|
public class AdaBelief implements IUpdater {
|
||||||
|
|
||||||
|
public static final double DEFAULT_LEARNING_RATE = 1e-3;
|
||||||
|
public static final double DEFAULT_EPSILON = 1e-8;
|
||||||
|
public static final double DEFAULT_BETA1_MEAN_DECAY = 0.9;
|
||||||
|
public static final double DEFAULT_BETA2_VAR_DECAY = 0.999;
|
||||||
|
|
||||||
|
@lombok.Builder.Default private double learningRate = DEFAULT_LEARNING_RATE; // learning rate
|
||||||
|
private ISchedule learningRateSchedule;
|
||||||
|
@lombok.Builder.Default private double beta1 = DEFAULT_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
||||||
|
@lombok.Builder.Default private double beta2 = DEFAULT_BETA2_VAR_DECAY; // gradient sqrt decay rate
|
||||||
|
@lombok.Builder.Default private double epsilon = DEFAULT_EPSILON;
|
||||||
|
|
||||||
|
public AdaBelief() {
|
||||||
|
this(DEFAULT_LEARNING_RATE, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY,
|
||||||
|
DEFAULT_EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaBelief(double learningRate){
|
||||||
|
this(learningRate, null, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaBelief(ISchedule learningRateSchedule){
|
||||||
|
this(Double.NaN, learningRateSchedule, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaBelief(double learningRate, double beta1, double beta2, double epsilon) {
|
||||||
|
this(learningRate, null, beta1, beta2, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
private AdaBelief(@JsonProperty("learningRate") double learningRate,
|
||||||
|
@JsonProperty("learningRateSchedule") ISchedule learningRateSchedule,
|
||||||
|
@JsonProperty("beta1") double beta1,
|
||||||
|
@JsonProperty("beta2") double beta2,
|
||||||
|
@JsonProperty("epsilon") double epsilon){
|
||||||
|
this.learningRate = learningRate;
|
||||||
|
this.learningRateSchedule = learningRateSchedule;
|
||||||
|
this.beta1 = beta1;
|
||||||
|
this.beta2 = beta2;
|
||||||
|
this.epsilon = epsilon;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long stateSize(long numParams) {
|
||||||
|
return 2 * numParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
|
||||||
|
AdaBeliefUpdater u = new AdaBeliefUpdater(this);
|
||||||
|
long[] gradientShape = viewArray.shape();
|
||||||
|
gradientShape = Arrays.copyOf(gradientShape, gradientShape.length);
|
||||||
|
gradientShape[1] /= 2;
|
||||||
|
u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
|
||||||
|
return u;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
|
||||||
|
AdaBeliefUpdater u = new AdaBeliefUpdater(this);
|
||||||
|
u.setState(updaterState, initializeStateArrays);
|
||||||
|
return u;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AdaBelief clone() {
|
||||||
|
return new AdaBelief(learningRate, learningRateSchedule, beta1, beta2, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getLearningRate(int iteration, int epoch){
|
||||||
|
if(learningRateSchedule != null){
|
||||||
|
return learningRateSchedule.valueAt(iteration, epoch);
|
||||||
|
}
|
||||||
|
return learningRate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasLearningRate() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setLrAndSchedule(double lr, ISchedule lrSchedule) {
|
||||||
|
this.learningRate = lr;
|
||||||
|
this.learningRateSchedule = lrSchedule;
|
||||||
|
}
|
||||||
|
|
||||||
|
//Partial builder implementation to give public no-arg constructor
|
||||||
|
public static class Builder {
|
||||||
|
public Builder(){ }
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue