Merge pull request #9186 from KonduitAI/qwr_adabelief

AdaBelief
master
Adam Gibson 2021-02-24 07:28:56 +09:00 committed by GitHub
commit c4b689e5c8
9 changed files with 740 additions and 1 deletion

View File

@@ -0,0 +1,96 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
// @author Abdelrauf (rauf@konduit.ai)
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(adabelief_updater, 3, 3, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initStateU = INPUT_VARIABLE(1);
const auto initStateM = INPUT_VARIABLE(2);
auto update = OUTPUT_VARIABLE(0);
auto stateU = OUTPUT_VARIABLE(1);
auto stateM = OUTPUT_VARIABLE(2);
// TODO: maybe we need an error here, like on the Java side
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADABELIEF UPDATER OP: input state U must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADABELIEF UPDATER OP: input state M must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str());
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
REQUIRE_TRUE(bParamsSupply, 0, "ADABELIEF UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
double dLr, dBeta1, dBeta2, dEpsilon;
if (block.width() > 3) {
const auto lr = INPUT_VARIABLE(3);
const auto beta1 = INPUT_VARIABLE(4);
const auto beta2 = INPUT_VARIABLE(5);
const auto epsilon = INPUT_VARIABLE(6);
REQUIRE_TRUE(lr->isScalar(), 0, "ADABELIEF UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(beta1->isScalar(), 0, "ADABELIEF UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
REQUIRE_TRUE(beta2->isScalar(), 0, "ADABELIEF UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADABELIEF UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dBeta1 = beta1->e<double>(0);
dBeta2 = beta2->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dBeta1 = T_ARG(1);
dBeta2 = T_ARG(2);
dEpsilon = T_ARG(3);
}
helpers::updaterAdaBelief(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
return Status::OK();
}
DECLARE_TYPES(adabelief_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

View File

@@ -144,6 +144,29 @@ namespace sd {
*/
#if NOT_EXCLUDED(OP_adam_updater)
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
#endif
// AdaBelief
/* Input arrays :
* 0 - input array with gradients.
* 1 - gradient state V
* 2 - gradient state M
* Optional :
* 3 - scalar learning rate value
* 4 - beta 1 value
* 5 - beta 2 value
* 6 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - beta 1 value
* 2 - beta 2 value
* 3 - epsilon
* Optional:
* I args
* 0 - iteration
*/
#if NOT_EXCLUDED(OP_adabelief_updater)
DECLARE_CONFIGURABLE_OP(adabelief_updater, 3, 3, true, 0, 0);
#endif
// AdaDelta
/* Input arrays :
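
For reference, here is a minimal sketch of driving this op from the Java side with the generic DynamicCustomOp builder; the builder calls, array shapes and values are illustrative assumptions, and the dedicated AdaBeliefUpdater wrapper added later in this PR wires the same arguments:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class AdaBeliefRawOpSketch {
    public static void main(String[] args) {
        INDArray grad   = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f).reshape(1, 5); // input 0: gradients
        INDArray stateU = Nd4j.zeros(1, 5);                                       // input 1: gradient state V
        INDArray stateM = Nd4j.zeros(1, 5);                                       // input 2: gradient state M

        DynamicCustomOp op = DynamicCustomOp.builder("adabelief_updater")
                .addInputs(grad, stateU, stateM)
                .addOutputs(grad, stateU, stateM)                    // outputs alias inputs: update in place
                .addFloatingPointArguments(0.001, 0.9, 0.999, 1e-8)  // T args 0-3: lr, beta1, beta2, epsilon
                .addIntegerArguments(0)                              // I arg 0: iteration
                .build();
        Nd4j.exec(op);
    }
}

Alternatively, as documented above, the four hyperparameters can be supplied as scalar input arrays 3-6 instead of T args.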

View File

@@ -0,0 +1,119 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
// @author Abdelrauf (rauf@konduit.ai)
// https://arxiv.org/pdf/2010.07468.pdf
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void adaBeliefUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update,
NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
const double dEpsilon, const int nIteration) {
const T* grad = gradient.bufferAsT<T>();
const T* initU = initStateU.bufferAsT<T>();
const T* initM = initStateM.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* stU = stateU.bufferAsT<T>();
T* stM = stateM.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
const T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() &&
update.ordering() == stateU.ordering() &&
stateU.ordering() == initStateU.ordering() &&
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
stU[i] = beta2 * initU[i] + (grad[i] - stM[i]) * (grad[i] - stM[i]) * (1 - beta2) + epsilon;
up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[i]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo());
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo());
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo());
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo());
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords);
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords);
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords);
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords);
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords);
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
}
}
}
}
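
In equations: the helper implements the AdaBelief rule from the paper, with the two bias corrections folded into the single step-size factor epsilonT. A sketch of the math exactly as the code computes it, where t is the 0-based iteration argument, g_t the gradient, and alpha the learning rate:

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t \\
s_t &= \beta_2\, s_{t-1} + (1 - \beta_2)\,(g_t - m_t)^2 + \epsilon \\
\alpha_t &= \alpha \cdot \frac{\sqrt{1 - \beta_2^{\,t+1}}}{1 - \beta_1^{\,t+1}} \\
u_t &= \frac{\alpha_t\, m_t}{\sqrt{s_t} + \epsilon}
\end{aligned}

If \alpha_t evaluates to 0, NaN or infinity, the code falls back to \alpha_t = \epsilon; note that \epsilon is added both inside the s_t accumulation (as in the paper) and again in the denominator.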

View File

@@ -0,0 +1,143 @@
/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
// @author Abdelrauf (rauf@konduit.ai)
// https://arxiv.org/pdf/2010.07468.pdf
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void adaBeliefUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm,
const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV,
const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto initU = reinterpret_cast<const T*>(vinv);
const auto initM = reinterpret_cast<const T*>(vinm);
auto up = reinterpret_cast<T*>(vz);
auto stU = reinterpret_cast<T*>(vstV);
auto stM = reinterpret_cast<T*>(vstM);
__shared__ Nd4jLong xLen;
__shared__ T epsilonT;
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
if (!bEWS || !bOrdering){
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
}
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void adaBeliefUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
adaBeliefUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, * stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
}
///////////////////////////////////////////////////////////////////
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
const double dEpsilon, const int nIteration) {
PointersManager manager(context, "adaBeliefUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(),
initStateU.specialBuffer(), initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(),
update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), stateU.specialShapeInfo(),
stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
manager.synchronize();
}
}
}
}
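
As a worked example of the launcher's ceiling division, assuming for illustration that MAX_NUM_THREADS is 1024 (so threadsPerBlock = 256) and a gradient of length 1,000,000:

blocksPerGrid = (1000000 + 256 - 1) / 256 = 3907

which covers every element (3907 * 256 = 1000192 >= 1000000); the grid-stride loop in the kernel means any smaller grid would still walk the whole array.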

View File

@@ -40,7 +40,7 @@ namespace helpers {
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
}
}
}

View File

@@ -1057,6 +1057,76 @@ TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) {
ASSERT_TRUE(stateM.isSameShape(results.at(2)));
ASSERT_TRUE(stateM.equalsTo(results.at(2)));
}
//
TEST_F(DeclarableOpsTests18, TestUpdaterAdaBelief1) {
//Here is the Python code used to generate the test numbers:
//import numpy as np
//alpha=0.001
//beta1=0.9
//beta2=0.999
//epsilon=1.e-8
//#https://arxiv.org/pdf/2010.07468.pdf
//def update( t, w, gradW, mt, st):
// mt = beta1* mt + (1- beta1)*gradW
// st = beta2* st + (1- beta2)*((gradW-mt)**2) + epsilon
// mt_corr = mt/(1- beta1**t)
// st_corr = st/(1- beta2**t)
// upW= alpha*(mt_corr/(np.sqrt(st_corr)+epsilon))
// w = w - upW
// return ( w, upW, mt, st )
//#if you want to test with more precision np.set_printoptions(precision=9)
//grad = np.array([1,2,3,4,5], dtype = np.float32)
//w=np.zeros(5, dtype = np.float32)
//mt=np.zeros(5, dtype = np.float32)
//st = np.zeros(5, dtype = np.float32)
//for t in range(1,4):
// w, upW, mt, st = update(t,w,grad, mt,st )
// print(f"---{t}----")
// print(f"update {upW}")
// print(f" s state {st} ")
// print(f" m state {mt} ")
NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32);
NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
NDArray update('c', { 1, 5 }, DataType::FLOAT32);
sd::ops::adabelief_updater op;
auto t=0;
Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { });
ASSERT_EQ(ND4J_STATUS_OK, status);
NDArray updateExp0('c', { 1, 5 }, { 0.0011111f, 0.00111111f, 0.00111111f, 0.00111111f, 0.00111111f }, DataType::FLOAT32);
NDArray stateV('c', { 1, 5 }, { 0.00081001f, 0.00324001f, 0.00729001f, 0.01296001f, 0.02025001f }, DataType::FLOAT32);
NDArray stateM0('c', { 1, 5 }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, DataType::FLOAT32);
ASSERT_TRUE(update.equalsTo(updateExp0));
ASSERT_TRUE(initU.equalsTo(stateV));
ASSERT_TRUE(initM.equalsTo(stateM0));
t=1;
status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { t});
ASSERT_EQ(ND4J_STATUS_OK, status);
NDArray updateExp1('c', { 1, 5 }, { 0.001168f, 0.001168f, 0.001168f, 0.001168f, 0.001168f}, DataType::FLOAT32);
NDArray stateV1('c', { 1, 5 }, { 0.00146531f, 0.00586118f, 0.01318763f, 0.02344466f, 0.03663227f }, DataType::FLOAT32);
NDArray stateM1('c', { 1, 5 }, { 0.19f, 0.38f, 0.57000005f, 0.76f, 0.95f }, DataType::FLOAT32);
ASSERT_TRUE(update.equalsTo(updateExp1));
ASSERT_TRUE(initU.equalsTo(stateV1));
ASSERT_TRUE(initM.equalsTo(stateM1));
t=2;
status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, {t});
ASSERT_EQ(ND4J_STATUS_OK, status);
NDArray updateExp2('c', { 1, 5 }, { 0.00122557f, 0.00122558f, 0.00122558f, 0.00122558f, 0.00122558f }, DataType::FLOAT32);
NDArray stateV2('c', { 1, 5 }, { 0.0019953f, 0.00798109f, 0.01795742f, 0.03192428f, 0.04988168f }, DataType::FLOAT32);
NDArray stateM2('c', { 1, 5 }, { 0.271f, 0.542f, 0.813f, 1.084f, 1.355f }, DataType::FLOAT32);
ASSERT_TRUE(update.equalsTo(updateExp2));
ASSERT_TRUE(initU.equalsTo(stateV2));
ASSERT_TRUE(initM.equalsTo(stateM2));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) {

View File

@@ -0,0 +1,49 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
//https://arxiv.org/pdf/2010.07468.pdf
public class AdaBeliefUpdater extends DynamicCustomOp {
public AdaBeliefUpdater() {
}
public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
}
public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
addInputArgument(gradients, stateU, stateM);
addOutputArgument(updates, updatedStateU, updatedStateM);
addTArgument(lr, beta1, beta2, epsilon);
addIArgument(iteration);
}
@Override
public String opName() {
return "adabelief_updater";
}
}
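
A minimal usage sketch (shapes and values are illustrative): with the three-array constructor the inputs are also registered as outputs, so grad, stateU and stateM are overwritten in place with the update and the new states.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater;
import org.nd4j.linalg.factory.Nd4j;

public class AdaBeliefOpExample {
    public static void main(String[] args) {
        INDArray grad   = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f).reshape(1, 5);
        INDArray stateU = Nd4j.zeros(1, 5);
        INDArray stateM = Nd4j.zeros(1, 5);
        // lr = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, iteration = 0
        Nd4j.exec(new AdaBeliefUpdater(grad, stateU, stateM, 0.001, 0.9, 0.999, 1e-8, 0));
        System.out.println(grad); // grad now holds the update (cf. TestUpdaterAdaBelief1 above)
    }
}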

View File

@@ -0,0 +1,107 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.learning;
import lombok.Data;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaBelief;
import java.util.HashMap;
import java.util.Map;
//https://arxiv.org/pdf/2010.07468.pdf
@Data
public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
public static final String M_STATE = "M";
public static final String S_STATE = "S";
private AdaBelief config;
private INDArray m, s; // gradient moving avg (m) and belief variance state (s)
private char gradientReshapeOrder;
public AdaBeliefUpdater(AdaBelief config) {
this.config = config;
}
@Override
public void setState(@NonNull Map<String, INDArray> stateMap, boolean initialize) {
if(!stateMap.containsKey(M_STATE) || !stateMap.containsKey(S_STATE) || stateMap.size() != 2){
throw new IllegalStateException("State map should contain only keys [" + M_STATE + "," + S_STATE + "] but has keys " + stateMap.keySet());
}
this.m = stateMap.get(M_STATE);
this.s = stateMap.get(S_STATE);
}
@Override
public Map<String, INDArray> getState() {
Map<String,INDArray> r = new HashMap<>();
r.put(M_STATE, m);
r.put(S_STATE, s);
return r;
}
@Override
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
if (!viewArray.isRowVector())
throw new IllegalArgumentException("Invalid input: expected row vector input");
if (initialize)
viewArray.assign(0);
long length = viewArray.length();
this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
this.s = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
//Reshape to match the expected shape of the input gradient arrays
this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f');
this.s = Shape.newShapeNoCopy(this.s, gradientShape, gradientOrder == 'f');
if (m == null || s == null)
throw new IllegalStateException("Could not correctly reshape gradient view arrays");
this.gradientReshapeOrder = gradientOrder;
}
/**
* Calculate the AdaBelief update based on the given gradient, applied in place.
*
* @param gradient  the gradient to get the update for; overwritten with the update
* @param iteration the current iteration, used for bias correction
* @param epoch     the current epoch, used by learning rate schedules
*/
@Override
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
if (m == null || s == null)
throw new IllegalStateException("Updater has not been initialized with view state");
double beta1 = config.getBeta1();
double beta2 = config.getBeta2();
double learningRate = config.getLearningRate(iteration, epoch);
double epsilon = config.getEpsilon();
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(gradient, s, m, learningRate, beta1, beta2, epsilon, iteration));
}
}

View File

@@ -0,0 +1,132 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.learning.config;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaBeliefUpdater;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import java.util.Arrays;
import java.util.Map;
/**
* AdaBelief
* https://arxiv.org/pdf/2010.07468.pdf
*/
@Data
@Builder(builderClassName = "Builder")
public class AdaBelief implements IUpdater {
public static final double DEFAULT_LEARNING_RATE = 1e-3;
public static final double DEFAULT_EPSILON = 1e-14;
public static final double DEFAULT_BETA1_MEAN_DECAY = 0.9;
public static final double DEFAULT_BETA2_VAR_DECAY = 0.999;
@lombok.Builder.Default private double learningRate = DEFAULT_LEARNING_RATE; // learning rate
private ISchedule learningRateSchedule;
@lombok.Builder.Default private double beta1 = DEFAULT_BETA1_MEAN_DECAY; // gradient moving avg decay rate
@lombok.Builder.Default private double beta2 = DEFAULT_BETA2_VAR_DECAY; // gradient sqrt decay rate
@lombok.Builder.Default private double epsilon = DEFAULT_EPSILON;
public AdaBelief() {
this(DEFAULT_LEARNING_RATE, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY,
DEFAULT_EPSILON);
}
public AdaBelief(double learningRate){
this(learningRate, null, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
}
public AdaBelief(ISchedule learningRateSchedule){
this(Double.NaN, learningRateSchedule, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
}
public AdaBelief(double learningRate, double beta1, double beta2, double epsilon) {
this(learningRate, null, beta1, beta2, epsilon);
}
private AdaBelief(@JsonProperty("learningRate") double learningRate,
@JsonProperty("learningRateSchedule") ISchedule learningRateSchedule,
@JsonProperty("beta1") double beta1,
@JsonProperty("beta2") double beta2,
@JsonProperty("epsilon") double epsilon){
this.learningRate = learningRate;
this.learningRateSchedule = learningRateSchedule;
this.beta1 = beta1;
this.beta2 = beta2;
this.epsilon = epsilon;
}
@Override
public long stateSize(long numParams) {
return 2 * numParams;
}
@Override
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
AdaBeliefUpdater u = new AdaBeliefUpdater(this);
long[] gradientShape = viewArray.shape();
gradientShape = Arrays.copyOf(gradientShape, gradientShape.length);
gradientShape[1] /= 2;
u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
return u;
}
@Override
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
AdaBeliefUpdater u = new AdaBeliefUpdater(this);
u.setState(updaterState, initializeStateArrays);
return u;
}
@Override
public AdaBelief clone() {
return new AdaBelief(learningRate, learningRateSchedule, beta1, beta2, epsilon);
}
@Override
public double getLearningRate(int iteration, int epoch){
if(learningRateSchedule != null){
return learningRateSchedule.valueAt(iteration, epoch);
}
return learningRate;
}
@Override
public boolean hasLearningRate() {
return true;
}
@Override
public void setLrAndSchedule(double lr, ISchedule lrSchedule) {
this.learningRate = lr;
this.learningRateSchedule = lrSchedule;
}
//Partial builder implementation to give public no-arg constructor
public static class Builder {
public Builder(){ }
}
}
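
Putting the pieces together, a minimal end-to-end sketch (sizes and hyperparameters are illustrative): stateSize reports 2 * numParams because the flat view holds the m half followed by the s half, and instantiate splits and reshapes that view via setStateViewArray.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AdaBelief;

public class AdaBeliefConfigExample {
    public static void main(String[] args) {
        AdaBelief config = AdaBelief.builder()
                .learningRate(1e-3).beta1(0.9).beta2(0.999).epsilon(1e-14)
                .build();

        long numParams = 5;
        INDArray view = Nd4j.create(new long[]{1, config.stateSize(numParams)}); // flat state: [m | s]
        GradientUpdater updater = config.instantiate(view, true);

        INDArray gradient = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f).reshape(1, 5);
        updater.applyUpdater(gradient, 0, 0); // gradient now holds the update, in place
    }
}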