diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaBeliefUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaBeliefUpdater.cpp
new file mode 100644
index 000000000..d77622870
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/updaters/adaBeliefUpdater.cpp
@@ -0,0 +1,96 @@
+/*
+ *  ******************************************************************************
+ *  *
+ *  *
+ *  * This program and the accompanying materials are made available under the
+ *  * terms of the Apache License, Version 2.0 which is available at
+ *  * https://www.apache.org/licenses/LICENSE-2.0.
+ *  *
+ *  * See the NOTICE file distributed with this work for additional
+ *  * information regarding copyright ownership.
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *  * License for the specific language governing permissions and limitations
+ *  * under the License.
+ *  *
+ *  * SPDX-License-Identifier: Apache-2.0
+ *  *****************************************************************************
+ */
+
+//
+// @author Oleh Semeniv (oleg.semeniv@gmail.com)
+// @author Abdelrauf (rauf@konduit.ai)
+//
+
+#include <system/op_boilerplate.h>
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/updatersHelpers.h>
+#include <helpers/ShapeUtils.h>
+#include <array/NDArray.h>
+
+namespace sd {
+    namespace ops {
+
+        CONFIGURABLE_OP_IMPL(adabelief_updater, 3, 3, true, 0, 0) {
+
+            const auto gradient = INPUT_VARIABLE(0);
+            const auto initStateU = INPUT_VARIABLE(1);
+            const auto initStateM = INPUT_VARIABLE(2);
+
+            auto update = OUTPUT_VARIABLE(0);
+            auto stateU = OUTPUT_VARIABLE(1);
+            auto stateM = OUTPUT_VARIABLE(2);
+
+            // todo: maybe we need an error here, like on the Java side
+            if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
+                return Status::OK();
+
+            REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADABELIEF UPDATER OP: input state V must have the same shape as gradient,"
+                         " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
+                         ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str());
+            REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADABELIEF UPDATER OP: input state M must have the same shape as gradient,"
+                         " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
+                         ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str());
+
+            bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
+
+            auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
+
+            REQUIRE_TRUE(bParamsSupply, 0, "ADABELIEF UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
+
+            double dLr, dBeta1, dBeta2, dEpsilon;
+
+            if (block.width() > 3) {
+                const auto lr = INPUT_VARIABLE(3);
+                const auto beta1 = INPUT_VARIABLE(4);
+                const auto beta2 = INPUT_VARIABLE(5);
+                const auto epsilon = INPUT_VARIABLE(6);
+
+                REQUIRE_TRUE(lr->isScalar(), 0, "ADABELIEF UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
+                REQUIRE_TRUE(beta1->isScalar(), 0, "ADABELIEF UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
+                REQUIRE_TRUE(beta2->isScalar(), 0, "ADABELIEF UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
+                REQUIRE_TRUE(epsilon->isScalar(), 0, "ADABELIEF UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
+
+                dLr = lr->e<double>(0);
+                dBeta1 = beta1->e<double>(0);
+                dBeta2 = beta2->e<double>(0);
+                dEpsilon = epsilon->e<double>(0);
+            }
+            else {
+                dLr = T_ARG(0);
+                dBeta1 = T_ARG(1);
+                dBeta2 = T_ARG(2);
+                dEpsilon = T_ARG(3);
+            }
+
+            helpers::updaterAdaBelief(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
+            return Status::OK();
+        }
+
+        DECLARE_TYPES(adabelief_updater) {
+            getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
+                ->setSameMode(true);
+        }
+
+    }
+}
diff --git a/libnd4j/include/ops/declarable/headers/updaters.h b/libnd4j/include/ops/declarable/headers/updaters.h
index d8028821e..9e04eb9eb 100644
--- a/libnd4j/include/ops/declarable/headers/updaters.h
+++ b/libnd4j/include/ops/declarable/headers/updaters.h
@@ -144,6 +144,29 @@ namespace sd {
      */
     #if NOT_EXCLUDED(OP_adam_updater)
     DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
+#endif
+    // AdaBelief
+    /* Input arrays :
+    * 0 - input array with gradients.
+    * 1 - gradient state V
+    * 2 - gradient state M
+    * Optional inputs:
+    * 3 - scalar learning rate value
+    * 4 - beta 1 value
+    * 5 - beta 2 value
+    * 6 - epsilon
+    * Optional:
+    * T args
+    * 0 - scalar learning rate value
+    * 1 - beta 1 value
+    * 2 - beta 2 value
+    * 3 - epsilon
+    * Optional:
+    * I args
+    * 0 - iteration
+    */
+#if NOT_EXCLUDED(OP_adabelief_updater)
+    DECLARE_CONFIGURABLE_OP(adabelief_updater, 3, 3, true, 0, 0);
 #endif
     // AdaDelta
     /* Input arrays :
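For reference, the rule implemented by the CPU and CUDA helpers below (and by the Python reference snippet in the test further down) is the AdaBelief update from the paper linked in the sources. In the notation of the code, with gradient g, first-moment state m, "belief" state s (called U/V in the kernels), and 0-based iteration t (hence the t+1 in the bias corrections):

\[ m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t \]
\[ s_t = \beta_2 s_{t-1} + (1-\beta_2)(g_t - m_t)^2 + \epsilon \]
\[ \mathrm{update}_t = \mathrm{lr}\cdot\frac{m_t/(1-\beta_1^{\,t+1})}{\sqrt{s_t/(1-\beta_2^{\,t+1})}+\epsilon} \]

The kernels fold both bias corrections into a single scalar, epsilonT = lr * sqrt(1 - beta2^(t+1)) / (1 - beta1^(t+1)), and compute update = m_t * epsilonT / (sqrt(s_t) + epsilon); this matches the expression above except for where epsilon enters the denominator (epsilon versus epsilon * sqrt(1 - beta2^(t+1))), which is within the tolerance the test below checks against.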
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaBelief.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaBelief.cpp
new file mode 100644
index 000000000..26496fde1
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaBelief.cpp
@@ -0,0 +1,119 @@
+/*
+ *  ******************************************************************************
+ *  *
+ *  *
+ *  * This program and the accompanying materials are made available under the
+ *  * terms of the Apache License, Version 2.0 which is available at
+ *  * https://www.apache.org/licenses/LICENSE-2.0.
+ *  *
+ *  * See the NOTICE file distributed with this work for additional
+ *  * information regarding copyright ownership.
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *  * License for the specific language governing permissions and limitations
+ *  * under the License.
+ *  *
+ *  * SPDX-License-Identifier: Apache-2.0
+ *  *****************************************************************************
+ */
+
+//
+// @author Oleh Semeniv (oleg.semeniv@gmail.com)
+// @author Abdelrauf (rauf@konduit.ai)
+//
+// https://arxiv.org/pdf/2010.07468.pdf
+
+#include <ops/declarable/helpers/updatersHelpers.h>
+#include <execution/Threads.h>
+#include <math/platformmath.h>
+#include <math/templatemath.h>
+
+namespace sd {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void adaBeliefUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update,
+                              NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
+                              const double dEpsilon, const int nIteration) {
+
+    const T* grad = gradient.bufferAsT<T>();
+    const T* initU = initStateU.bufferAsT<T>();
+    const T* initM = initStateM.bufferAsT<T>();
+
+    T* up = update.bufferAsT<T>();
+    T* stU = stateU.bufferAsT<T>();
+    T* stM = stateM.bufferAsT<T>();
+
+    const T lr = static_cast<T>(dLr);
+    const T beta1 = static_cast<T>(dBeta1);
+    const T beta2 = static_cast<T>(dBeta2);
+    const T epsilon = static_cast<T>(dEpsilon);
+    const T iteration = static_cast<T>(nIteration);
+
+    const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
+    const T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
+
+    T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
+    if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
+        epsilonT = epsilon;
+
+    bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
+    bool bSameOrdering = gradient.ordering() == update.ordering() &&
+                         update.ordering() == stateU.ordering() &&
+                         stateU.ordering() == initStateU.ordering() &&
+                         stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
+
+    if (bEws1 && bSameOrdering) {
+
+        auto func = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; i++) {
+                stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
+                stU[i] = beta2 * initU[i] + (grad[i] - stM[i]) * (grad[i] - stM[i]) * (1 - beta2) + epsilon;
+
+                up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[i]) + epsilon);
+            }
+        };
+
+        samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
+        return;
+    }
+
+    bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo());
+    bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo());
+    bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo());
+    bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo());
+    bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo());
+
+    auto func = PRAGMA_THREADS_FOR {
+
+        int coords[MAX_RANK];
+        for (auto i = start; i < stop; i++) {
+            shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords);
+            const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords);
+            const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords);
+            const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords);
+            const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords);
+            const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords);
+            const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords);
+
+            stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
+            stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
+
+            up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
+    return;
+}
+
+void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
+    BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
+}
+
+}
+}
+}
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaBelief.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaBelief.cu
new file mode 100644
index 000000000..20966c8e7
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaBelief.cu
@@ -0,0 +1,143 @@
+/* ******************************************************************************
+ *
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * See the NOTICE file distributed with this work for additional
+ * information regarding copyright ownership.
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Oleh Semeniv (oleg.semeniv@gmail.com)
+// @author Abdelrauf (rauf@konduit.ai)
+//
+// https://arxiv.org/pdf/2010.07468.pdf
+
+#include <ops/declarable/helpers/updatersHelpers.h>
+#include <helpers/PointersManager.h>
+#include <math/platformmath.h>
+#include <math/templatemath.h>
+#include <system/op_boilerplate.h>
+
+namespace sd {
+namespace ops {
+namespace helpers {
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ void adaBeliefUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm,
+                                     const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV,
+                                     const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
+                                     const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
+
+    const auto grad = reinterpret_cast<const T*>(vx);
+    const auto initU = reinterpret_cast<const T*>(vinv);
+    const auto initM = reinterpret_cast<const T*>(vinm);
+
+    auto up = reinterpret_cast<T*>(vz);
+    auto stU = reinterpret_cast<T*>(vstV);
+    auto stM = reinterpret_cast<T*>(vstM);
+
+    __shared__ Nd4jLong xLen;
+    __shared__ T epsilonT;
+    __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
+
+    if (threadIdx.x == 0) {
+        xLen = shape::length(xShapeInfo);
+
+        T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
+        T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
+
+        epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
+        if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
+            epsilonT = epsilon;
+
+        bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
+               1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
+               1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
+        bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
+                    shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
+                    shape::order(stvShapeInfo) == shape::order(invShapeInfo);
+
+        bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
+        bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
+        bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
+        bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
+        bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
+    }
+    __syncthreads();
+
+    int coords[MAX_RANK];
+
+    for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
+
+        auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
+
+        if (!bEWS || !bOrdering) {
+
+            shape::index2coords(i, xShapeInfo, coords);
+            xOffset = shape::getOffset(xShapeInfo, coords);
+            zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
+            initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
+            stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
+            initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
+            stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
+        }
+
+        stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
+        stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
+
+        up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+linkage void adaBeliefUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
+                                          const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
+                                          void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
+                                          void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
+
+    const T lr = static_cast<T>(dLr);
+    const T beta1 = static_cast<T>(dBeta1);
+    const T beta2 = static_cast<T>(dBeta2);
+    const T epsilon = static_cast<T>(dEpsilon);
+    const T iteration = static_cast<T>(nIteration);
+
+    adaBeliefUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
+        vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
+}
+
+///////////////////////////////////////////////////////////////////
+void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
+                      NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
+                      const double dEpsilon, const int nIteration) {
+
+    PointersManager manager(context, "adaBeliefUpdater");
+
+    const int threadsPerBlock = MAX_NUM_THREADS / 4;
+    const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+
+    NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
+
+    BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(),
+        initStateU.specialBuffer(), initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(),
+        update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), stateU.specialShapeInfo(),
+        stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
+
+    NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
+
+    manager.synchronize();
+}
+
+}
+}
+}
diff --git a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h
index 2bc6d7d12..0f612c206 100644
--- a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h
+++ b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h
@@ -40,7 +40,7 @@ namespace helpers {
     void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
     void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
     void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
-
+    void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
 }
 }
 }
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp
index 14d37cbf9..d6a3bb41d 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp
@@ -1057,6 +1057,76 @@ TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) {
     ASSERT_TRUE(stateM.isSameShape(results.at(2)));
     ASSERT_TRUE(stateM.equalsTo(results.at(2)));
 }
+//
+TEST_F(DeclarableOpsTests18, TestUpdaterAdaBelief1) {
+    // here is the python code used for generating test numbers
+    // import numpy as np
+    // alpha = 0.001
+    // beta1 = 0.9
+    // beta2 = 0.999
+    // epsilon = 1.e-8
+    // # https://arxiv.org/pdf/2010.07468.pdf
+    // def update(t, w, gradW, mt, st):
+    //     mt = beta1 * mt + (1 - beta1) * gradW
+    //     st = beta2 * st + (1 - beta2) * ((gradW - mt)**2) + epsilon
+    //     mt_corr = mt / (1 - beta1**t)
+    //     st_corr = st / (1 - beta2**t)
+    //     upW = alpha * (mt_corr / (np.sqrt(st_corr) + epsilon))
+    //     w = w - upW
+    //     return (w, upW, mt, st)
+    // # if you want to test with more precision: np.set_printoptions(precision=9)
+    // grad = np.array([1,2,3,4,5], dtype = np.float32)
+    // w = np.zeros(5, dtype = np.float32)
+    // mt = np.zeros(5, dtype = np.float32)
+    // st = np.zeros(5, dtype = np.float32)
+    // for t in range(1,4):
+    //     w, upW, mt, st = update(t, w, grad, mt, st)
+    //     print(f"---{t}----")
+    //     print(f"update {upW}")
+    //     print(f" s state {st} ")
+    //     print(f" m state {mt} ")
+
+    NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32);
+    NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
+    NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
+
+    NDArray update('c', { 1, 5 }, DataType::FLOAT32);
+
+    sd::ops::adabelief_updater op;
+    auto t = 0;
+    Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { });
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    NDArray updateExp0('c', { 1, 5 }, { 0.0011111f, 0.00111111f, 0.00111111f, 0.00111111f, 0.00111111f }, DataType::FLOAT32);
+    NDArray stateV('c', { 1, 5 }, { 0.00081001f, 0.00324001f, 0.00729001f, 0.01296001f, 0.02025001f }, DataType::FLOAT32);
+    NDArray stateM0('c', { 1, 5 }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f }, DataType::FLOAT32);
+    ASSERT_TRUE(update.equalsTo(updateExp0));
+    ASSERT_TRUE(initU.equalsTo(stateV));
+    ASSERT_TRUE(initM.equalsTo(stateM0));
+
+    t = 1;
+    status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { t });
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    NDArray updateExp1('c', { 1, 5 }, { 0.001168f, 0.001168f, 0.001168f, 0.001168f, 0.001168f }, DataType::FLOAT32);
+    NDArray stateV1('c', { 1, 5 }, { 0.00146531f, 0.00586118f, 0.01318763f, 0.02344466f, 0.03663227f }, DataType::FLOAT32);
+    NDArray stateM1('c', { 1, 5 }, { 0.19f, 0.38f, 0.57000005f, 0.76f, 0.95f }, DataType::FLOAT32);
+    ASSERT_TRUE(update.equalsTo(updateExp1));
+    ASSERT_TRUE(initU.equalsTo(stateV1));
+    ASSERT_TRUE(initM.equalsTo(stateM1));
+
+    t = 2;
+    status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { t });
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    NDArray updateExp2('c', { 1, 5 }, { 0.00122557f, 0.00122558f, 0.00122558f, 0.00122558f, 0.00122558f }, DataType::FLOAT32);
+    NDArray stateV2('c', { 1, 5 }, { 0.0019953f, 0.00798109f, 0.01795742f, 0.03192428f, 0.04988168f }, DataType::FLOAT32);
+    NDArray stateM2('c', { 1, 5 }, { 0.271f, 0.542f, 0.813f, 1.084f, 1.355f }, DataType::FLOAT32);
+
+    ASSERT_TRUE(update.equalsTo(updateExp2));
+    ASSERT_TRUE(initU.equalsTo(stateV2));
+    ASSERT_TRUE(initM.equalsTo(stateM2));
+}
+
 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) {
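The same sequence can be exercised from Java through the raw op wrapper added below. A minimal sketch, assuming a float gradient of shape {1, 5} mirroring the C++ test above (class and variable names here are illustrative only; the op wrapper constructor and Nd4j.exec call are the ones introduced in this change):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class AdaBeliefOpExample {
    public static void main(String[] args) {
        // gradient plus the two zero-initialized state arrays, as in TestUpdaterAdaBelief1
        INDArray grad = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}).reshape(1, 5);
        INDArray stateU = Nd4j.zeros(DataType.FLOAT, 1, 5);
        INDArray stateM = Nd4j.zeros(DataType.FLOAT, 1, 5);

        // iteration 0; the 3-array constructor reuses the inputs as outputs,
        // so grad/stateU/stateM are overwritten with update/new S/new M in place
        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(
                grad, stateU, stateM, 0.001, 0.9, 0.999, 1e-8, 0));

        System.out.println(grad); // expected ~0.00111111 per element, matching updateExp0 above
    }
}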
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaBeliefUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaBeliefUpdater.java
new file mode 100644
index 000000000..b5d518f28
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaBeliefUpdater.java
@@ -0,0 +1,49 @@
+/*
+ *  ******************************************************************************
+ *  *
+ *  *
+ *  * This program and the accompanying materials are made available under the
+ *  * terms of the Apache License, Version 2.0 which is available at
+ *  * https://www.apache.org/licenses/LICENSE-2.0.
+ *  *
+ *  * See the NOTICE file distributed with this work for additional
+ *  * information regarding copyright ownership.
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *  * License for the specific language governing permissions and limitations
+ *  * under the License.
+ *  *
+ *  * SPDX-License-Identifier: Apache-2.0
+ *  *****************************************************************************
+ */
+
+package org.nd4j.linalg.api.ops.impl.updaters;
+
+import lombok.NonNull;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+// https://arxiv.org/pdf/2010.07468.pdf
+
+public class AdaBeliefUpdater extends DynamicCustomOp {
+
+    public AdaBeliefUpdater() {
+    }
+
+    public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
+        this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
+    }
+
+    public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
+        addInputArgument(gradients, stateU, stateM);
+        addOutputArgument(updates, updatedStateU, updatedStateM);
+        addTArgument(lr, beta1, beta2, epsilon);
+        addIArgument(iteration);
+    }
+
+    @Override
+    public String opName() {
+        return "adabelief_updater";
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java
new file mode 100644
index 000000000..7d6dbb16c
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java
@@ -0,0 +1,107 @@
+/*
+ *  ******************************************************************************
+ *  *
+ *  *
+ *  * This program and the accompanying materials are made available under the
+ *  * terms of the Apache License, Version 2.0 which is available at
+ *  * https://www.apache.org/licenses/LICENSE-2.0.
+ *  *
+ *  * See the NOTICE file distributed with this work for additional
+ *  * information regarding copyright ownership.
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *  * License for the specific language governing permissions and limitations
+ *  * under the License.
+ *  *
+ *  * SPDX-License-Identifier: Apache-2.0
+ *  *****************************************************************************
+ */
+
+package org.nd4j.linalg.learning;
+
+import lombok.Data;
+import lombok.NonNull;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.NDArrayIndex;
+import org.nd4j.linalg.learning.config.AdaBelief;
+
+import java.util.HashMap;
+import java.util.Map;
+
+// https://arxiv.org/pdf/2010.07468.pdf
+
+@Data
+public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
+    public static final String M_STATE = "M";
+    public static final String S_STATE = "S";
+
+    private AdaBelief config;
+    private INDArray m, s; // moving avg & squared gradients
+
+    private char gradientReshapeOrder;
+
+    public AdaBeliefUpdater(AdaBelief config) {
+        this.config = config;
+    }
+
+    @Override
+    public void setState(@NonNull Map<String, INDArray> stateMap, boolean initialize) {
+        if (!stateMap.containsKey(M_STATE) || !stateMap.containsKey(S_STATE) || stateMap.size() != 2) {
+            throw new IllegalStateException("State map should contain only keys [" + M_STATE + "," + S_STATE + "] but has keys " + stateMap.keySet());
+        }
+        this.m = stateMap.get(M_STATE);
+        this.s = stateMap.get(S_STATE);
+    }
+
+    @Override
+    public Map<String, INDArray> getState() {
+        Map<String, INDArray> r = new HashMap<>();
+        r.put(M_STATE, m);
+        r.put(S_STATE, s);
+        return r;
+    }
+
+    @Override
+    public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
+        if (!viewArray.isRowVector())
+            throw new IllegalArgumentException("Invalid input: expect row vector input");
+        if (initialize)
+            viewArray.assign(0);
+        long length = viewArray.length();
+        this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
+        this.s = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
+
+        // Reshape to match the expected shape of the input gradient arrays
+        this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f');
+        this.s = Shape.newShapeNoCopy(this.s, gradientShape, gradientOrder == 'f');
+        if (m == null || s == null)
+            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
+
+        this.gradientReshapeOrder = gradientOrder;
+    }
+
+    /**
+     * Apply the AdaBelief update to the given gradient (in place)
+     *
+     * @param gradient  the gradient to get the update for
+     * @param iteration the current iteration
+     * @param epoch     the current epoch
+     */
+    @Override
+    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
+        if (m == null || s == null)
+            throw new IllegalStateException("Updater has not been initialized with view state");
+
+        double beta1 = config.getBeta1();
+        double beta2 = config.getBeta2();
+        double learningRate = config.getLearningRate(iteration, epoch);
+        double epsilon = config.getEpsilon();
+
+        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(gradient, s, m, learningRate, beta1, beta2, epsilon, iteration));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AdaBelief.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AdaBelief.java
new file mode 100644
index 000000000..aa5d3f00d
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config/AdaBelief.java
@@ -0,0 +1,132 @@
+/*
+ *  ******************************************************************************
+ *  *
+ *  *
+ *  * This program and the accompanying materials are made available under the
+ *  * terms of the Apache License, Version 2.0 which is available at
+ *  * https://www.apache.org/licenses/LICENSE-2.0.
+ *  *
+ *  * See the NOTICE file distributed with this work for additional
+ *  * information regarding copyright ownership.
+ *  * Unless required by applicable law or agreed to in writing, software
+ *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ *  * License for the specific language governing permissions and limitations
+ *  * under the License.
+ *  *
+ *  * SPDX-License-Identifier: Apache-2.0
+ *  *****************************************************************************
+ */
+
+package org.nd4j.linalg.learning.config;
+
+import lombok.Builder;
+import lombok.Data;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.learning.AdaBeliefUpdater;
+import org.nd4j.linalg.learning.GradientUpdater;
+import org.nd4j.linalg.schedule.ISchedule;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
+
+import java.util.Arrays;
+import java.util.Map;
+
+/**
+ * AdaBelief
+ * https://arxiv.org/pdf/2010.07468.pdf
+ */
+@Data
+@Builder(builderClassName = "Builder")
+public class AdaBelief implements IUpdater {
+
+    public static final double DEFAULT_LEARNING_RATE = 1e-3;
+    public static final double DEFAULT_EPSILON = 1e-14;
+    public static final double DEFAULT_BETA1_MEAN_DECAY = 0.9;
+    public static final double DEFAULT_BETA2_VAR_DECAY = 0.999;
+
+    @lombok.Builder.Default private double learningRate = DEFAULT_LEARNING_RATE; // learning rate
+    private ISchedule learningRateSchedule;
+    @lombok.Builder.Default private double beta1 = DEFAULT_BETA1_MEAN_DECAY; // gradient moving avg decay rate
+    @lombok.Builder.Default private double beta2 = DEFAULT_BETA2_VAR_DECAY; // gradient squared-deviation decay rate
+    @lombok.Builder.Default private double epsilon = DEFAULT_EPSILON;
+
+    public AdaBelief() {
+        this(DEFAULT_LEARNING_RATE, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
+    }
+
+    public AdaBelief(double learningRate) {
+        this(learningRate, null, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
+    }
+
+    public AdaBelief(ISchedule learningRateSchedule) {
+        this(Double.NaN, learningRateSchedule, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
+    }
+
+    public AdaBelief(double learningRate, double beta1, double beta2, double epsilon) {
+        this(learningRate, null, beta1, beta2, epsilon);
+    }
+
+    private AdaBelief(@JsonProperty("learningRate") double learningRate,
+                      @JsonProperty("learningRateSchedule") ISchedule learningRateSchedule,
+                      @JsonProperty("beta1") double beta1,
+                      @JsonProperty("beta2") double beta2,
+                      @JsonProperty("epsilon") double epsilon) {
+        this.learningRate = learningRate;
+        this.learningRateSchedule = learningRateSchedule;
+        this.beta1 = beta1;
+        this.beta2 = beta2;
+        this.epsilon = epsilon;
+    }
+
+    @Override
+    public long stateSize(long numParams) {
+        return 2 * numParams;
+    }
+
+    @Override
+    public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
+        AdaBeliefUpdater u = new AdaBeliefUpdater(this);
+        long[] gradientShape = viewArray.shape();
+        gradientShape = Arrays.copyOf(gradientShape, gradientShape.length);
+        gradientShape[1] /= 2;
+        u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
+        return u;
+    }
+
+    @Override
+    public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
+        AdaBeliefUpdater u = new AdaBeliefUpdater(this);
+        u.setState(updaterState, initializeStateArrays);
+        return u;
+    }
+
+    @Override
+    public AdaBelief clone() {
+        return new AdaBelief(learningRate, learningRateSchedule, beta1, beta2, epsilon);
+    }
+
+    @Override
+    public double getLearningRate(int iteration, int epoch) {
+        if (learningRateSchedule != null) {
+            return learningRateSchedule.valueAt(iteration, epoch);
+        }
+        return learningRate;
+    }
+
+    @Override
+    public boolean hasLearningRate() {
+        return true;
+    }
+
+    @Override
+    public void setLrAndSchedule(double lr, ISchedule lrSchedule) {
+        this.learningRate = lr;
+        this.learningRateSchedule = lrSchedule;
+    }
+
+    // Partial builder implementation to give public no-arg constructor
+    public static class Builder {
+        public Builder() { }
+    }
+}
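Putting the config and updater together, a minimal end-to-end sketch (variable names and shapes are illustrative only; in practice the AdaBelief config would be passed as the IUpdater of a training configuration rather than driven by hand, and the state view would be managed by the framework):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AdaBelief;

public class AdaBeliefConfigExample {
    public static void main(String[] args) {
        AdaBelief config = new AdaBelief(0.001, 0.9, 0.999, 1e-14);

        // state view holds 2 * numParams values (M and S halves) as a single row vector
        long numParams = 5;
        INDArray view = Nd4j.zeros(DataType.FLOAT, 1, config.stateSize(numParams));
        GradientUpdater updater = config.instantiate(view, true);

        // the gradient is overwritten in place with the AdaBelief update for this iteration
        INDArray gradient = Nd4j.createFromArray(new float[]{1, 2, 3, 4, 5}).reshape(1, 5);
        updater.applyUpdater(gradient, /*iteration*/ 0, /*epoch*/ 0);
        System.out.println(gradient);
    }
}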