commit
c4b689e5c8
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
//
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
// @author Abdelrauf(rauf@konduit.ai)
|
||||
|
||||
#include <ops/declarable/headers/updaters.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <array/NDArray.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(adabelief_updater, 3, 3, true, 0, 0) {
|
||||
|
||||
const auto gradient = INPUT_VARIABLE(0);
|
||||
const auto initStateU = INPUT_VARIABLE(1);
|
||||
const auto initStateM = INPUT_VARIABLE(2);
|
||||
|
||||
auto update = OUTPUT_VARIABLE(0);
|
||||
auto stateU = OUTPUT_VARIABLE(1);
|
||||
auto stateM = OUTPUT_VARIABLE(2);
|
||||
|
||||
// todo maybe we need an error like on Java side
|
||||
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADABELIEF UPDATER OP: input state V must have the same shape as gradient,"
|
||||
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
|
||||
ShapeUtils::shapeAsString(initStateU->shapeInfo()).c_str());
|
||||
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADABELIEF UPDATER OP: input state M must have the same shape as gradient,"
|
||||
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->shapeInfo()).c_str(),
|
||||
ShapeUtils::shapeAsString(initStateM->shapeInfo()).c_str());
|
||||
|
||||
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
|
||||
|
||||
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
|
||||
|
||||
REQUIRE_TRUE(bParamsSupply, 0, "ADABELIEF UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
|
||||
|
||||
double dLr, dBeta1, dBeta2, dEpsilon;
|
||||
|
||||
if (block.width() > 3) {
|
||||
const auto lr = INPUT_VARIABLE(3);
|
||||
const auto beta1 = INPUT_VARIABLE(4);
|
||||
const auto beta2 = INPUT_VARIABLE(5);
|
||||
const auto epsilon = INPUT_VARIABLE(6);
|
||||
|
||||
REQUIRE_TRUE(lr->isScalar(), 0, "ADABELIEF UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||
REQUIRE_TRUE(beta1->isScalar(), 0, "ADABELIEF UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
|
||||
REQUIRE_TRUE(beta2->isScalar(), 0, "ADABELIEF UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
|
||||
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADABELIEF UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||
|
||||
dLr = lr->e<double>(0);
|
||||
dBeta1 = beta1->e<double>(0);
|
||||
dBeta2 = beta2->e<double>(0);
|
||||
dEpsilon = epsilon->e<double>(0);
|
||||
}
|
||||
else {
|
||||
dLr = T_ARG(0);
|
||||
dBeta1 = T_ARG(1);
|
||||
dBeta2 = T_ARG(2);
|
||||
dEpsilon = T_ARG(3);
|
||||
}
|
||||
|
||||
helpers::updaterAdaBelief(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(adabelief_updater) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -144,6 +144,29 @@ namespace sd {
|
|||
*/
|
||||
#if NOT_EXCLUDED(OP_adam_updater)
|
||||
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
|
||||
#endif
|
||||
// AdaBelief
|
||||
/* Input arrays :
|
||||
* 0 - input array with gradients.
|
||||
* 1 - gradient state V
|
||||
* 2 - gradient state M
|
||||
* Optional :
|
||||
* 3 - scalar learning rate value
|
||||
* 4 - beta 1 value
|
||||
* 5 - beta 2 value
|
||||
* 6 - epsilon
|
||||
* Optional:
|
||||
* T args
|
||||
* 0 - scalar learning rate value
|
||||
* 1 - beta 1 value
|
||||
* 2 - beta 2 value
|
||||
* 3 - epsilon
|
||||
* Optional:
|
||||
* I args
|
||||
* 0 - iteration
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_adabelief_updater)
|
||||
DECLARE_CONFIGURABLE_OP(adabelief_updater, 3, 3, true, 0, 0);
|
||||
#endif
|
||||
// AdaDelta
|
||||
/* Input arrays :
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
//
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
// @author Abdelrauf (rauf@konduit.ai)
|
||||
|
||||
// https://arxiv.org/pdf/2010.07468.pdf
|
||||
|
||||
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <math/platformmath.h>
|
||||
#include <math/templatemath.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void adaBeliefUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update,
|
||||
NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
|
||||
const double dEpsilon, const int nIteration) {
|
||||
|
||||
const T* grad = gradient.bufferAsT<T>();
|
||||
const T* initU = initStateU.bufferAsT<T>();
|
||||
const T* initM = initStateM.bufferAsT<T>();
|
||||
|
||||
T* up = update.bufferAsT<T>();
|
||||
T* stU = stateU.bufferAsT<T>();
|
||||
T* stM = stateM.bufferAsT<T>();
|
||||
|
||||
const T lr = static_cast<T>(dLr);
|
||||
const T beta1 = static_cast<T>(dBeta1);
|
||||
const T beta2 = static_cast<T>(dBeta2);
|
||||
const T epsilon = static_cast<T>(dEpsilon);
|
||||
const T iteration = static_cast<T>(nIteration);
|
||||
|
||||
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||
const T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
|
||||
|
||||
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
|
||||
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||
epsilonT = epsilon;
|
||||
|
||||
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
|
||||
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||
update.ordering() == stateU.ordering() &&
|
||||
stateU.ordering() == initStateU.ordering() &&
|
||||
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
|
||||
|
||||
if (bEws1 && bSameOrdering) {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i++) {
|
||||
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
|
||||
stU[i] = beta2 * initU[i] + (grad[i] - stM[i]) * (grad[i] - stM[i]) * (1 - beta2) + epsilon;
|
||||
|
||||
up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[i]) + epsilon);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||
return;
|
||||
}
|
||||
|
||||
bool bXZsame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), update.shapeInfo());
|
||||
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateU.shapeInfo());
|
||||
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateU.shapeInfo());
|
||||
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), initStateM.shapeInfo());
|
||||
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.shapeInfo(), stateM.shapeInfo());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
int coords[MAX_RANK];
|
||||
for (auto i = start; i < stop; i++) {
|
||||
shape::index2coordsCPU(start, i, gradient.shapeInfo(), coords);
|
||||
const auto xOffset = shape::getOffset(gradient.shapeInfo(), coords);
|
||||
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.shapeInfo(), coords);
|
||||
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.shapeInfo(), coords);
|
||||
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.shapeInfo(), coords);
|
||||
const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.shapeInfo(), coords);
|
||||
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.shapeInfo(), coords);
|
||||
|
||||
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||
stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
|
||||
|
||||
up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||
return;
|
||||
}
|
||||
|
||||
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
/* ******************************************************************************
|
||||
*
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* See the NOTICE file distributed with this work for additional
|
||||
* information regarding copyright ownership.
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
// @author Abdelrauf (rauf@konduit.ai)
|
||||
|
||||
// https://arxiv.org/pdf/2010.07468.pdf
|
||||
|
||||
#include <system/op_boilerplate.h>
|
||||
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||
#include <helpers/PointersManager.h>
|
||||
#include <math/platformmath.h>
|
||||
#include <math/templatemath.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__global__ void adaBeliefUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm,
|
||||
const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV,
|
||||
const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
|
||||
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
|
||||
|
||||
const auto grad = reinterpret_cast<const T*>(vx);
|
||||
const auto initU = reinterpret_cast<const T*>(vinv);
|
||||
const auto initM = reinterpret_cast<const T*>(vinm);
|
||||
|
||||
auto up = reinterpret_cast<T*>(vz);
|
||||
auto stU = reinterpret_cast<T*>(vstV);
|
||||
auto stM = reinterpret_cast<T*>(vstM);
|
||||
|
||||
__shared__ Nd4jLong xLen;
|
||||
__shared__ T epsilonT;
|
||||
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(xShapeInfo);
|
||||
|
||||
T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||
T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
|
||||
|
||||
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
|
||||
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||
epsilonT = epsilon;
|
||||
|
||||
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
|
||||
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
|
||||
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
|
||||
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
|
||||
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
|
||||
|
||||
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
|
||||
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
|
||||
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
|
||||
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int coords[MAX_RANK];
|
||||
|
||||
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||
|
||||
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
|
||||
|
||||
if (!bEWS || !bOrdering){
|
||||
|
||||
shape::index2coords(i, xShapeInfo, coords);
|
||||
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
|
||||
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
|
||||
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
|
||||
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
|
||||
}
|
||||
|
||||
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||
stU[stUOffset] = beta2 * initU[initUOffset] + (grad[xOffset] - stM[stMOffset]) * (grad[xOffset] - stM[stMOffset]) * (1 - beta2) + epsilon;
|
||||
|
||||
up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
linkage void adaBeliefUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
|
||||
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
|
||||
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||
|
||||
const T lr = static_cast<T>(dLr);
|
||||
const T beta1 = static_cast<T>(dBeta1);
|
||||
const T beta2 = static_cast<T>(dBeta2);
|
||||
const T epsilon = static_cast<T>(dEpsilon);
|
||||
const T iteration = static_cast<T>(nIteration);
|
||||
adaBeliefUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, * stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
|
||||
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
|
||||
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
|
||||
const double dEpsilon, const int nIteration) {
|
||||
|
||||
PointersManager manager(context, "adamUpdater");
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||
|
||||
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaBeliefUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(),
|
||||
initStateU.specialBuffer(), initStateU.specialShapeInfo(), initStateM.specialBuffer(), initStateM.specialShapeInfo(),
|
||||
update.specialBuffer(), update.specialShapeInfo(), stateU.specialBuffer(), stateU.specialShapeInfo(),
|
||||
stateM.specialBuffer(), stateM.specialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||
|
||||
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -40,7 +40,7 @@ namespace helpers {
|
|||
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
|
||||
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||
|
||||
void updaterAdaBelief(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1057,6 +1057,76 @@ TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) {
|
|||
ASSERT_TRUE(stateM.isSameShape(results.at(2)));
|
||||
ASSERT_TRUE(stateM.equalsTo(results.at(2)));
|
||||
}
|
||||
//
|
||||
TEST_F(DeclarableOpsTests18, TestUpdaterAdaBelief1) {
|
||||
//here is the python code used for generating test numbers
|
||||
//import numpy as np
|
||||
//alpha=0.001
|
||||
//beta1=0.9
|
||||
//beta2=0.999
|
||||
//epsilon=1.e-8
|
||||
//#https://arxiv.org/pdf/2010.07468.pdf
|
||||
//def update( t, w, gradW, mt, st):
|
||||
// mt = beta1* mt + (1- beta1)*gradW
|
||||
// st = beta2* st + (1- beta2)*((gradW-mt)**2) + epsilon
|
||||
// mt_corr = mt/(1- beta1**t)
|
||||
// st_corr = st/(1- beta2**t)
|
||||
// upW= alpha*(mt_corr/(np.sqrt(st_corr)+epsilon))
|
||||
// w = w - upW
|
||||
// return ( w, upW, mt, st )
|
||||
//#if you want to test with more precision np.set_printoptions(precision=9)
|
||||
//grad = np.array([1,2,3,4,5], dtype = np.float32)
|
||||
//w=np.zeros(5, dtype = np.float32)
|
||||
//mt=np.zeros(5, dtype = np.float32)
|
||||
//st = np.zeros(5, dtype = np.float32)
|
||||
//for t in range(1,4):
|
||||
// w, upW, mt, st = update(t,w,grad, mt,st )
|
||||
// print(f"---{t}----")
|
||||
// print(f"update {upW}")
|
||||
// print(f" s state {st} ")
|
||||
// print(f" m state {mt} ")
|
||||
|
||||
|
||||
NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32);
|
||||
NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
|
||||
NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32);
|
||||
|
||||
NDArray update('c', { 1, 5 }, DataType::FLOAT32);
|
||||
|
||||
sd::ops::adabelief_updater op;
|
||||
auto t=0;
|
||||
Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { });
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
NDArray updateExp0('c', { 1, 5 }, { 0.0011111f, 0.00111111f, 0.00111111f, 0.00111111f, 0.00111111f }, DataType::FLOAT32);
|
||||
NDArray stateV('c', { 1, 5 }, { 0.00081001f, 0.00324001f, 0.00729001f, 0.01296001f, 0.02025001f }, DataType::FLOAT32);
|
||||
NDArray stateM0('c', { 1, 5 }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, DataType::FLOAT32);
|
||||
ASSERT_TRUE(update.equalsTo(updateExp0));
|
||||
ASSERT_TRUE(initU.equalsTo(stateV));
|
||||
ASSERT_TRUE(initM.equalsTo(stateM0));
|
||||
t=1;
|
||||
status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { t});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
NDArray updateExp1('c', { 1, 5 }, { 0.001168f, 0.001168f, 0.001168f, 0.001168f, 0.001168f}, DataType::FLOAT32);
|
||||
NDArray stateV1('c', { 1, 5 }, { 0.00146531f, 0.00586118f, 0.01318763f, 0.02344466f, 0.03663227f }, DataType::FLOAT32);
|
||||
NDArray stateM1('c', { 1, 5 }, { 0.19f, 0.38f, 0.57000005f, 0.76f, 0.95f }, DataType::FLOAT32);
|
||||
ASSERT_TRUE(update.equalsTo(updateExp1));
|
||||
ASSERT_TRUE(initU.equalsTo(stateV1));
|
||||
ASSERT_TRUE(initM.equalsTo(stateM1));
|
||||
t=2;
|
||||
status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, {t});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
NDArray updateExp2('c', { 1, 5 }, { 0.00122557f, 0.00122558f, 0.00122558f, 0.00122558f, 0.00122558f }, DataType::FLOAT32);
|
||||
NDArray stateV2('c', { 1, 5 }, { 0.0019953f, 0.00798109f, 0.01795742f, 0.03192428f, 0.04988168f }, DataType::FLOAT32);
|
||||
NDArray stateM2('c', { 1, 5 }, { 0.271f, 0.542f, 0.813f, 1.084f, 1.355f }, DataType::FLOAT32);
|
||||
|
||||
ASSERT_TRUE(update.equalsTo(updateExp2));
|
||||
ASSERT_TRUE(initU.equalsTo(stateV2));
|
||||
ASSERT_TRUE(initM.equalsTo(stateM2));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) {
|
||||
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
//https://arxiv.org/pdf/2010.07468.pdf
|
||||
|
||||
public class AdaBeliefUpdater extends DynamicCustomOp {
|
||||
|
||||
public AdaBeliefUpdater() {
|
||||
}
|
||||
|
||||
public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
|
||||
}
|
||||
|
||||
public AdaBeliefUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||
addInputArgument(gradients, stateU, stateM);
|
||||
addOutputArgument(updates, updatedStateU, updatedStateM);
|
||||
addTArgument(lr, beta1, beta2, epsilon);
|
||||
addIArgument(iteration);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "adabelief_updater";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.learning;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.learning.config.AdaBelief;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
//https://arxiv.org/pdf/2010.07468.pdf
|
||||
|
||||
|
||||
@Data
|
||||
public class AdaBeliefUpdater implements GradientUpdater<AdaBelief> {
|
||||
public static final String M_STATE = "M";
|
||||
public static final String S_STATE = "S";
|
||||
|
||||
private AdaBelief config;
|
||||
private INDArray m, s; // moving avg & sqrd gradients
|
||||
|
||||
private char gradientReshapeOrder;
|
||||
|
||||
public AdaBeliefUpdater(AdaBelief config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void setState(@NonNull Map<String, INDArray> stateMap, boolean initialize) {
|
||||
if(!stateMap.containsKey(M_STATE) || !stateMap.containsKey(S_STATE) || stateMap.size() != 2){
|
||||
throw new IllegalStateException("State map should contain only keys [" + M_STATE + "," + S_STATE + "] but has keys " + stateMap.keySet());
|
||||
}
|
||||
this.m = stateMap.get(M_STATE);
|
||||
this.s = stateMap.get(S_STATE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, INDArray> getState() {
|
||||
Map<String,INDArray> r = new HashMap<>();
|
||||
r.put(M_STATE, m);
|
||||
r.put(S_STATE, s);
|
||||
return r;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
|
||||
if (!viewArray.isRowVector())
|
||||
throw new IllegalArgumentException("Invalid input: expect row vector input");
|
||||
if (initialize)
|
||||
viewArray.assign(0);
|
||||
long length = viewArray.length();
|
||||
this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
|
||||
this.s = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));
|
||||
|
||||
//Reshape to match the expected shape of the input gradient arrays
|
||||
this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f');
|
||||
this.s = Shape.newShapeNoCopy(this.s, gradientShape, gradientOrder == 'f');
|
||||
if (m == null || s == null)
|
||||
throw new IllegalStateException("Could not correctly reshape gradient view arrays");
|
||||
|
||||
this.gradientReshapeOrder = gradientOrder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the update based on the given gradient
|
||||
*
|
||||
* @param gradient the gradient to get the update for
|
||||
* @param iteration
|
||||
* @return the gradient
|
||||
*/
|
||||
@Override
|
||||
public void applyUpdater(INDArray gradient, int iteration, int epoch) {
|
||||
if (m == null || s == null)
|
||||
throw new IllegalStateException("Updater has not been initialized with view state");
|
||||
|
||||
double beta1 = config.getBeta1();
|
||||
double beta2 = config.getBeta2();
|
||||
double learningRate = config.getLearningRate(iteration, epoch);
|
||||
double epsilon = config.getEpsilon();
|
||||
|
||||
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaBeliefUpdater(gradient, s, m, learningRate, beta1, beta2, epsilon, iteration));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.learning.config;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.learning.AdaBeliefUpdater;
|
||||
import org.nd4j.linalg.learning.GradientUpdater;
|
||||
import org.nd4j.linalg.schedule.ISchedule;
|
||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* AdaBelief
|
||||
* https://arxiv.org/pdf/2010.07468.pdf
|
||||
*/
|
||||
@Data
|
||||
@Builder(builderClassName = "Builder")
|
||||
public class AdaBelief implements IUpdater {
|
||||
|
||||
public static final double DEFAULT_LEARNING_RATE = 1e-3;
|
||||
public static final double DEFAULT_EPSILON = 1e-14;
|
||||
public static final double DEFAULT_BETA1_MEAN_DECAY = 0.9;
|
||||
public static final double DEFAULT_BETA2_VAR_DECAY = 0.999;
|
||||
|
||||
@lombok.Builder.Default private double learningRate = DEFAULT_LEARNING_RATE; // learning rate
|
||||
private ISchedule learningRateSchedule;
|
||||
@lombok.Builder.Default private double beta1 = DEFAULT_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
||||
@lombok.Builder.Default private double beta2 = DEFAULT_BETA2_VAR_DECAY; // gradient sqrt decay rate
|
||||
@lombok.Builder.Default private double epsilon = DEFAULT_EPSILON;
|
||||
|
||||
public AdaBelief() {
|
||||
this(DEFAULT_LEARNING_RATE, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY,
|
||||
DEFAULT_EPSILON);
|
||||
}
|
||||
|
||||
public AdaBelief(double learningRate){
|
||||
this(learningRate, null, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
|
||||
}
|
||||
|
||||
public AdaBelief(ISchedule learningRateSchedule){
|
||||
this(Double.NaN, learningRateSchedule, DEFAULT_BETA1_MEAN_DECAY, DEFAULT_BETA2_VAR_DECAY, DEFAULT_EPSILON);
|
||||
}
|
||||
|
||||
public AdaBelief(double learningRate, double beta1, double beta2, double epsilon) {
|
||||
this(learningRate, null, beta1, beta2, epsilon);
|
||||
}
|
||||
|
||||
private AdaBelief(@JsonProperty("learningRate") double learningRate,
|
||||
@JsonProperty("learningRateSchedule") ISchedule learningRateSchedule,
|
||||
@JsonProperty("beta1") double beta1,
|
||||
@JsonProperty("beta2") double beta2,
|
||||
@JsonProperty("epsilon") double epsilon){
|
||||
this.learningRate = learningRate;
|
||||
this.learningRateSchedule = learningRateSchedule;
|
||||
this.beta1 = beta1;
|
||||
this.beta2 = beta2;
|
||||
this.epsilon = epsilon;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long stateSize(long numParams) {
|
||||
return 2 * numParams;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
|
||||
AdaBeliefUpdater u = new AdaBeliefUpdater(this);
|
||||
long[] gradientShape = viewArray.shape();
|
||||
gradientShape = Arrays.copyOf(gradientShape, gradientShape.length);
|
||||
gradientShape[1] /= 2;
|
||||
u.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
|
||||
return u;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GradientUpdater instantiate(Map<String, INDArray> updaterState, boolean initializeStateArrays) {
|
||||
AdaBeliefUpdater u = new AdaBeliefUpdater(this);
|
||||
u.setState(updaterState, initializeStateArrays);
|
||||
return u;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdaBelief clone() {
|
||||
return new AdaBelief(learningRate, learningRateSchedule, beta1, beta2, epsilon);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLearningRate(int iteration, int epoch){
|
||||
if(learningRateSchedule != null){
|
||||
return learningRateSchedule.valueAt(iteration, epoch);
|
||||
}
|
||||
return learningRate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasLearningRate() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLrAndSchedule(double lr, ISchedule lrSchedule) {
|
||||
this.learningRate = lr;
|
||||
this.learningRateSchedule = lrSchedule;
|
||||
}
|
||||
|
||||
//Partial builder implementation to give public no-arg constructor
|
||||
public static class Builder {
|
||||
public Builder(){ }
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue