Learning updaters for gradient (#335)
* libnd4j raw implementation of sgd upader Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections and simple test added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections after discussion Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j integrate applyScalar Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j raw implementation of rmsPropUpdater on cpu Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fix operations declaration Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j rmsPropUpdater added, test cases for sgd, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed several typos Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some fixes and improvements for rmsPropUpdater based on Java tests Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed cuda implementation, update tests and corrected behavior according java tests Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j adaGrad updater added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one minor fix for ada grad Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j several more fixes for ada_grad Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j nesterovs updater added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed nesterovs updater behavior, several typos and rename file Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one minor typo Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j ada max updater added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed several typos in adaMax updater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed several typos in adaMaxUpdater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j several fixes for adaMax, added Adam Updater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j adaDeltaUpdater added, minor fixes for adamUpdater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j several fixes for adaDeltaUpdater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j nadamUpdater added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j one more correction for nadam updater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j several fixes for nadam updater and added amsGradUpdater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j several typos fixed in amsGradUpdater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections and added f order support rmsProp updater Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added support of f order for all updaters and modify tests for testing in place Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed issues for updates when not in place mode used, added tests for f order Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j added input shape checks Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections for different cases handling Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some code clean up and optimize per request Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j updaters refactoring after review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * SgdUpdater wrapper Signed-off-by: raver119 <raver119@gmail.com> * first test Signed-off-by: raver119 <raver119@gmail.com> * RmsPropUpdater added Signed-off-by: raver119 <raver119@gmail.com> * NadamUpdater + NesterovsUpdater Signed-off-by: raver119 <raver119@gmail.com> * AmsGradUpdater Signed-off-by: raver119 <raver119@gmail.com> * AdamUpdater added Signed-off-by: raver119 <raver119@gmail.com> * AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater Signed-off-by: raver119 <raver119@gmail.com> * AdaGradUpdater test added Signed-off-by: raver119 <raver119@gmail.com> * libnd4j remove input parameters parsing through NDArray, split implementation of helpers to separate files, added some rename, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j next step to split operations implementation into separate files Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j merge master and minor corrections Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j revert some changes of split implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j forgot to add header file Signed-off-by: Oleg <oleg.semeniv@gmail.com> * public default constructors Signed-off-by: raver119 <raver119@gmail.com> * ImportClassMapping updated Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
015147b713
commit
69c92ca5ae
|
@ -45,6 +45,7 @@
|
||||||
#include <ops/declarable/headers/util.h>
|
#include <ops/declarable/headers/util.h>
|
||||||
#include <ops/declarable/headers/BarnesHutTsne.h>
|
#include <ops/declarable/headers/BarnesHutTsne.h>
|
||||||
#include <ops/declarable/headers/images.h>
|
#include <ops/declarable/headers/images.h>
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
#include <system/dll.h>
|
#include <system/dll.h>
|
||||||
#include <helpers/shape.h>
|
#include <helpers/shape.h>
|
||||||
#include <helpers/TAD.h>
|
#include <helpers/TAD.h>
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initStateMsg = INPUT_VARIABLE(1);
|
||||||
|
const auto initStateMsdx = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateMsg = OUTPUT_VARIABLE(1);
|
||||||
|
auto stateMsdx = OUTPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
if (gradient->isEmpty() || initStateMsg->isEmpty() || initStateMsdx->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, "ADA_DELTA UPDATER OP: input state Msg must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateMsg->getShapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, "ADA_DELTA UPDATER OP: input state Msdx must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateMsdx->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 5 == block.width() || 2 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dRho, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
const auto rho = INPUT_VARIABLE(3);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(4);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(rho->isScalar(), 0, "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead got rank %i!", rho->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dRho = rho->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dRho = T_ARG(0);
|
||||||
|
dEpsilon = T_ARG(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, dEpsilon);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(ada_delta_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initState = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateH = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
if (gradient->isEmpty() || initState->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initState), 0, "ADA_GRAD UPDATER OP: input state must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
|
||||||
|
bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 2) {
|
||||||
|
const auto lr = INPUT_VARIABLE(2);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(3);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dEpsilon = T_ARG(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, *stateH, dLr, dEpsilon);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(ada_grad_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initStateU = INPUT_VARIABLE(1);
|
||||||
|
const auto initStateM = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateU = OUTPUT_VARIABLE(1);
|
||||||
|
auto stateM = OUTPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
// todo maybe we need an error like on Java side
|
||||||
|
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADA_MAX UPDATER OP: input state V must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADA_MAX UPDATER OP: input state M must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
|
||||||
|
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
int iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dBeta1, dBeta2, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
const auto lr = INPUT_VARIABLE(3);
|
||||||
|
const auto beta1 = INPUT_VARIABLE(4);
|
||||||
|
const auto beta2 = INPUT_VARIABLE(5);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(6);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(beta1->isScalar(), 0, "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
|
||||||
|
REQUIRE_TRUE(beta2->isScalar(), 0, "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dBeta1 = beta1->e<double>(0);
|
||||||
|
dBeta2 = beta2->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dBeta1 = T_ARG(1);
|
||||||
|
dBeta2 = T_ARG(2);
|
||||||
|
dEpsilon = T_ARG(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(ada_max_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initStateU = INPUT_VARIABLE(1);
|
||||||
|
const auto initStateM = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateU = OUTPUT_VARIABLE(1);
|
||||||
|
auto stateM = OUTPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
// todo maybe we need an error like on Java side
|
||||||
|
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADAM UPDATER OP: input state V must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADAM UPDATER OP: input state M must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dBeta1, dBeta2, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
const auto lr = INPUT_VARIABLE(3);
|
||||||
|
const auto beta1 = INPUT_VARIABLE(4);
|
||||||
|
const auto beta2 = INPUT_VARIABLE(5);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(6);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "ADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(beta1->isScalar(), 0, "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
|
||||||
|
REQUIRE_TRUE(beta2->isScalar(), 0, "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dBeta1 = beta1->e<double>(0);
|
||||||
|
dBeta2 = beta2->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dBeta1 = T_ARG(1);
|
||||||
|
dBeta2 = T_ARG(2);
|
||||||
|
dEpsilon = T_ARG(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(adam_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initStateV = INPUT_VARIABLE(1);
|
||||||
|
const auto initStateM = INPUT_VARIABLE(2);
|
||||||
|
const auto initStateH = INPUT_VARIABLE(3);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateV = OUTPUT_VARIABLE(1);
|
||||||
|
auto stateM = OUTPUT_VARIABLE(2);
|
||||||
|
auto stateH = OUTPUT_VARIABLE(3);
|
||||||
|
|
||||||
|
// todo maybe we need an error like on Java side
|
||||||
|
if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || initStateH->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "AMSGRAD UPDATER OP: input state Msg must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, "AMSGRAD UPDATER OP: input state Msdx must have the same shape as gradient!,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateH->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 8 == block.width() || 4 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dBeta1, dBeta2, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 4) {
|
||||||
|
const auto lr = INPUT_VARIABLE(4);
|
||||||
|
const auto beta1 = INPUT_VARIABLE(5);
|
||||||
|
const auto beta2 = INPUT_VARIABLE(6);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(7);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(beta1->isScalar(), 0, "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
|
||||||
|
REQUIRE_TRUE(beta2->isScalar(), 0, "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dBeta1 = beta1->e<double>(0);
|
||||||
|
dBeta2 = beta2->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dBeta1 = T_ARG(1);
|
||||||
|
dBeta2 = T_ARG(2);
|
||||||
|
dEpsilon = T_ARG(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, *initStateM, *initStateH,
|
||||||
|
*update, *stateV, *stateM, *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(ams_grad_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initStateV = INPUT_VARIABLE(1);
|
||||||
|
const auto initStateM = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateV = OUTPUT_VARIABLE(1);
|
||||||
|
auto stateM = OUTPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
// todo maybe we need an error like on Java side
|
||||||
|
if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "NADAM UPDATER OP: input state M must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "NADAM UPDATER OP: input state V must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
auto nIteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dBeta1, dBeta2, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
const auto lr = INPUT_VARIABLE(3);
|
||||||
|
const auto beta1 = INPUT_VARIABLE(4);
|
||||||
|
const auto beta2 = INPUT_VARIABLE(5);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(6);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "NADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(beta1->isScalar(), 0, "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
|
||||||
|
REQUIRE_TRUE(beta2->isScalar(), 0, "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "NADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dBeta1 = beta1->e<double>(0);
|
||||||
|
dBeta2 = beta2->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dBeta1 = T_ARG(1);
|
||||||
|
dBeta2 = T_ARG(2);
|
||||||
|
dEpsilon = T_ARG(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, *initStateM, *update, *stateV, *stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(nadam_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initState = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateV = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
if (gradient->isEmpty() || initState->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initState), 0, "NESTEROVS UPDATER OP: input state Msg must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!");
|
||||||
|
|
||||||
|
double dLr, dMomentum;
|
||||||
|
|
||||||
|
if (block.width() > 2) {
|
||||||
|
const auto lr = INPUT_VARIABLE(2);
|
||||||
|
const auto momentum = INPUT_VARIABLE(3);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(momentum->isScalar(), 0, "NESTEROVS UPDATER OP: Momentum has to be a scalar, but instead got rank %i!", momentum->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dMomentum = momentum->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dMomentum = T_ARG(1);
|
||||||
|
}
|
||||||
|
helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, *update, *stateV, dLr, dMomentum);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(nesterovs_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto gradient = INPUT_VARIABLE(0);
|
||||||
|
const auto initState = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
auto update = OUTPUT_VARIABLE(0);
|
||||||
|
auto stateG = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
if (gradient->isEmpty() || initState->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(gradient->isSameShape(initState), 0, "RMS_PROB UPDATER OP: input state must have the same shape as gradient,"
|
||||||
|
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
|
||||||
|
ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str());
|
||||||
|
|
||||||
|
bool bParamsSupply = 5 == block.width() || 3 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bParamsSupply, 0, "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were not provided!");
|
||||||
|
|
||||||
|
double dLr, dRmsDecay, dEpsilon;
|
||||||
|
|
||||||
|
if (block.width() > 2) {
|
||||||
|
const auto lr = INPUT_VARIABLE(2);
|
||||||
|
const auto rmsDecay = INPUT_VARIABLE(3);
|
||||||
|
const auto epsilon = INPUT_VARIABLE(4);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
REQUIRE_TRUE(rmsDecay->isScalar(), 0, "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but instead got rank %i!", rmsDecay->rankOf());
|
||||||
|
REQUIRE_TRUE(epsilon->isScalar(), 0, "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
|
||||||
|
|
||||||
|
dLr = lr->e<double>(0);
|
||||||
|
dRmsDecay = rmsDecay->e<double>(0);
|
||||||
|
dEpsilon = epsilon->e<double>(0);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dLr = T_ARG(0);
|
||||||
|
dRmsDecay = T_ARG(1);
|
||||||
|
dEpsilon = T_ARG(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, *stateG, dLr, dRmsDecay, dEpsilon);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(rms_prop_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/updaters.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) {
|
||||||
|
|
||||||
|
const auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (input->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
bool bLearningRate = 2 == block.width() || 1 == block.getTArguments()->size();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!");
|
||||||
|
|
||||||
|
if (block.width() > 1) {
|
||||||
|
const auto lr = INPUT_VARIABLE(1);
|
||||||
|
REQUIRE_TRUE(lr->isScalar(), 0, "SGD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
|
||||||
|
|
||||||
|
input->applyScalarArr(scalar::Multiply, *lr, *output);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
input->applyScalar(scalar::Multiply, T_ARG(0), *output);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(sgd_updater) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,210 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef LIBND4J_HEADERS_UPDATERS_H
|
||||||
|
#define LIBND4J_HEADERS_UPDATERS_H
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/common.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SGD updater
|
||||||
|
* Input arrays:
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* Optional:
|
||||||
|
* 1 - scalar learning rate value
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_sgd_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RmsPropUpdater updater
|
||||||
|
* Input arrays:
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - Initial state
|
||||||
|
* Optional:
|
||||||
|
* 2 - scalar learning rate value
|
||||||
|
* 3 - scalar rms decay
|
||||||
|
* 4 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - scalar rms decay
|
||||||
|
* 2 - epsilon
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_rms_prop_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// AdaGrad
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - historical grad state
|
||||||
|
* Optional :
|
||||||
|
* 2 - scalar learning rate value
|
||||||
|
* 3 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - epsilon
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_ada_grad_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// AdaMax
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - gradient state V
|
||||||
|
* 2 - gradient state M
|
||||||
|
* Optional :
|
||||||
|
* 3 - scalar learning rate value
|
||||||
|
* 4 - beta 1 value
|
||||||
|
* 5 - beta 2 value
|
||||||
|
* 6 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - beta 1 value
|
||||||
|
* 2 - beta 2 value
|
||||||
|
* 3 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* I args
|
||||||
|
* 0 - iteration
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_ada_max_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// Nesterov's momentum
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - V grad state
|
||||||
|
* Optional :
|
||||||
|
* 2 - scalar learning rate value
|
||||||
|
* 3 - scalar momentum value
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - learning rate value
|
||||||
|
* 1 - momentum value
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_nesterovs_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// Adam
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - gradient state V
|
||||||
|
* 2 - gradient state M
|
||||||
|
* Optional :
|
||||||
|
* 3 - scalar learning rate value
|
||||||
|
* 4 - beta 1 value
|
||||||
|
* 5 - beta 2 value
|
||||||
|
* 6 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - beta 1 value
|
||||||
|
* 2 - beta 2 value
|
||||||
|
* 3 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* I args
|
||||||
|
* 0 - iteration
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_adam_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// AdaDelta
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - gradient state V
|
||||||
|
* 2 - gradient state M
|
||||||
|
* Optional :
|
||||||
|
* 3 - rho value
|
||||||
|
* 6 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - rho
|
||||||
|
* 1 - epsilon
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_ada_delta_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// Nadam
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - gradient state V
|
||||||
|
* 2 - gradient state M
|
||||||
|
* Optional :
|
||||||
|
* 3 - scalar learning rate value
|
||||||
|
* 4 - beta 1 value
|
||||||
|
* 5 - beta 2 value
|
||||||
|
* 6 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - beta 1 value
|
||||||
|
* 2 - beta 2 value
|
||||||
|
* 3 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* I args
|
||||||
|
* 0 - iteration
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_nadam_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
// AmsGrad
|
||||||
|
/* Input arrays :
|
||||||
|
* 0 - input array with gradients.
|
||||||
|
* 1 - gradient state V - sqrd gradients
|
||||||
|
* 2 - gradient state M - moving avg
|
||||||
|
* 3 - gradient state H - max
|
||||||
|
* Optional :
|
||||||
|
* 4 - scalar learning rate value
|
||||||
|
* 5 - beta 1 value
|
||||||
|
* 6 - beta 2 value
|
||||||
|
* 7 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* T args
|
||||||
|
* 0 - scalar learning rate value
|
||||||
|
* 1 - beta 1 value
|
||||||
|
* 2 - beta 2 value
|
||||||
|
* 3 - epsilon
|
||||||
|
* Optional:
|
||||||
|
* I args
|
||||||
|
* 0 - iteration
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_ams_grad_updater)
|
||||||
|
DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,108 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void adaDeltaUpdater_(const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx,
|
||||||
|
NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* initMsg = initStateMsg.bufferAsT<T>();
|
||||||
|
const T* initMsdx = initStateMsdx.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* stMsg = stateMsg.bufferAsT<T>();
|
||||||
|
T* stMsdx = stateMsdx.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T rho = static_cast<T>(dRho);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T rhoT = (1 - rho);
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateMsg.ews() && 1 == initStateMsg.ews() && 1 == stateMsdx.ews() && 1 == initStateMsdx.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||||
|
update.ordering() == stateMsdx.ordering() &&
|
||||||
|
stateMsdx.ordering() == initStateMsdx.ordering() &&
|
||||||
|
stateMsdx.ordering() == initStateMsg.ordering() && stateMsg.ordering() == initStateMsg.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT;
|
||||||
|
|
||||||
|
up[i] = grad[i] * (sd::math::nd4j_sqrt<T, T>(initMsdx[i] + epsilon) / sd::math::nd4j_sqrt<T, T>(stMsg[i] + epsilon));
|
||||||
|
|
||||||
|
stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsg.getShapeInfo());
|
||||||
|
bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsg.getShapeInfo());
|
||||||
|
bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsdx.getShapeInfo());
|
||||||
|
bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsdx.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < gradient.lengthOf(); i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(initStateMsg.getShapeInfo(), coords);
|
||||||
|
const auto stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stateMsg.getShapeInfo(), coords);
|
||||||
|
const auto initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(initStateMsdx.getShapeInfo(), coords);
|
||||||
|
const auto stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stateMsdx.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
|
||||||
|
stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT;
|
||||||
|
|
||||||
|
up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt<T, T>(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt<T, T>(stMsg[stMsgOffset] + epsilon));
|
||||||
|
|
||||||
|
stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx,
|
||||||
|
NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, (gradient, initStateMsg, initStateMsdx, update, stateMsg, stateMsdx, dRho, dEpsilon), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* init = initState.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* st = stateH.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
st[i] = init[i] + grad[i] * grad[i];
|
||||||
|
up[i] = (lr * grad[i]) / (math::nd4j_sqrt<T, T>(st[i]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
|
||||||
|
bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
|
||||||
|
const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset];
|
||||||
|
up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH,
|
||||||
|
const double dLr, const double dEpsilon) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* initU = initStateU.bufferAsT<T>();
|
||||||
|
const T* initM = initStateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* stU = stateU.bufferAsT<T>();
|
||||||
|
T* stM = stateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
T epsilonT = lr / (1.0 - beta1T);
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||||
|
update.ordering() == stateU.ordering() &&
|
||||||
|
stateU.ordering() == initStateU.ordering() &&
|
||||||
|
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
//m = B_1 * m + (1-B_1)*grad
|
||||||
|
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
|
||||||
|
//u = max(B_2 * u, |grad|)
|
||||||
|
stU[i] = sd::math::nd4j_max((beta2 * initU[i]), sd::math::nd4j_abs(grad[i])) + 1e-32;
|
||||||
|
|
||||||
|
up[i] = stM[i] * epsilonT / stU[i];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo());
|
||||||
|
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo());
|
||||||
|
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
|
||||||
|
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords);
|
||||||
|
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.getShapeInfo(), coords);
|
||||||
|
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
|
||||||
|
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
//m = B_1 * m + (1-B_1)*grad
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||||
|
//u = max(B_2 * u, |grad|)
|
||||||
|
stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32;
|
||||||
|
|
||||||
|
up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,113 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update,
|
||||||
|
NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
|
||||||
|
const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* initU = initStateU.bufferAsT<T>();
|
||||||
|
const T* initM = initStateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* stU = stateU.bufferAsT<T>();
|
||||||
|
T* stM = stateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
const T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
|
||||||
|
|
||||||
|
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||||
|
update.ordering() == stateU.ordering() &&
|
||||||
|
stateU.ordering() == initStateU.ordering() &&
|
||||||
|
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
|
||||||
|
stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2);
|
||||||
|
|
||||||
|
up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[i]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo());
|
||||||
|
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo());
|
||||||
|
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
|
||||||
|
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords);
|
||||||
|
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.getShapeInfo(), coords);
|
||||||
|
const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
|
||||||
|
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||||
|
stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2);
|
||||||
|
|
||||||
|
up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH,
|
||||||
|
NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* initV = initStateV.bufferAsT<T>();
|
||||||
|
const T* initM = initStateM.bufferAsT<T>();
|
||||||
|
const T* initH = initStateH.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* stV = stateV.bufferAsT<T>();
|
||||||
|
T* stM = stateM.bufferAsT<T>();
|
||||||
|
T* stH = stateH.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1.0 - sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1)));
|
||||||
|
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
const T mbeta1 = (1 - beta1);
|
||||||
|
const T mbeta2 = (1 - beta2);
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() &&
|
||||||
|
1 == stateV.ews() && 1 == initStateV.ews() && 1 == stateH.ews() && 1 == initStateH.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||||
|
update.ordering() == stateV.ordering() &&
|
||||||
|
stateV.ordering() == initStateV.ordering() &&
|
||||||
|
stateV.ordering() == initStateM.ordering() &&
|
||||||
|
stateM.ordering() == initStateM.ordering() &&
|
||||||
|
stateM.ordering() == initStateH.ordering() && stateH.ordering() == initStateH.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
stM[i] = beta1 * initM[i] + grad[i] * mbeta1;
|
||||||
|
stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2;
|
||||||
|
stH[i] = sd::math::nd4j_max(initH[i], stV[i]);
|
||||||
|
|
||||||
|
up[i] = epsilonT * stM[i] / (sd::math::nd4j_sqrt<T, T>(stH[i]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo());
|
||||||
|
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo());
|
||||||
|
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
|
||||||
|
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
|
||||||
|
bool bXInHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateH.getShapeInfo());
|
||||||
|
bool bXStHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords);
|
||||||
|
const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords);
|
||||||
|
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
|
||||||
|
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
|
||||||
|
const auto initHOffset = bXInHSame ? xOffset : shape::getOffset(initStateH.getShapeInfo(), coords);
|
||||||
|
const auto stHOffset = bXStHSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1;
|
||||||
|
stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
|
||||||
|
stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]);
|
||||||
|
|
||||||
|
up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt<T, T>(stH[stHOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH,
|
||||||
|
NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdater_, (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,116 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM,
|
||||||
|
NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1,
|
||||||
|
const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* initV = initStateV.bufferAsT<T>();
|
||||||
|
const T* initM = initStateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* stV = stateV.bufferAsT<T>();
|
||||||
|
T* stM = stateM.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
const T mbeta1T = 1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
const T mbeta1 = (1 - beta1);
|
||||||
|
const T mbeta2 = (1 - beta2);
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateV.ews() && 1 == initStateV.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() &&
|
||||||
|
update.ordering() == stateV.ordering() &&
|
||||||
|
stateV.ordering() == initStateV.ordering() &&
|
||||||
|
stateV.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
auto oneMinusBeta1Grad = grad[i] * mbeta1;
|
||||||
|
|
||||||
|
stM[i] = beta1 * initM[i] + oneMinusBeta1Grad;
|
||||||
|
stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2;
|
||||||
|
|
||||||
|
up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt<T, T>(stV[i]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo());
|
||||||
|
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo());
|
||||||
|
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
|
||||||
|
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords);
|
||||||
|
const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords);
|
||||||
|
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
|
||||||
|
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
auto oneMinusBeta1Grad = grad[xOffset] * mbeta1;
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad;
|
||||||
|
stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
|
||||||
|
|
||||||
|
up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt<T, T>(stV[stVOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM,
|
||||||
|
NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, (gradient, initStateV, initStateM, update, stateV, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* init = initState.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* st = stateV.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T momentum = static_cast<T>(dMomentum);
|
||||||
|
const T momentumT = (-momentum - 1);
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && 1 == initState.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateV.ordering() && stateV.ordering() == initState.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
T prevState = momentum * init[i];
|
||||||
|
st[i] = prevState - lr * grad[i];
|
||||||
|
up[i] = prevState + momentumT * st[i];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
|
||||||
|
bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
|
||||||
|
const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
T prevState = momentum * init[initOffset];
|
||||||
|
st[stOffset] = prevState - lr * grad[xOffset];
|
||||||
|
up[zOffset] = prevState + momentumT * st[stOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, (gradient, initState, update, stateV, dLr, dMomentum), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG,
|
||||||
|
const double dLr, const double dRmsDecay, const double dEpsilon) {
|
||||||
|
|
||||||
|
const T* grad = gradient.bufferAsT<T>();
|
||||||
|
const T* init = initState.bufferAsT<T>();
|
||||||
|
|
||||||
|
T* up = update.bufferAsT<T>();
|
||||||
|
T* st = stateG.bufferAsT<T>();
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T rmsDecay = static_cast<T>(dRmsDecay);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
|
||||||
|
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && 1 == initState.ews();
|
||||||
|
bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateG.ordering() && stateG.ordering() == initState.ordering();
|
||||||
|
|
||||||
|
if (bEws1 && bSameOrdering) {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay) ;
|
||||||
|
up[i] = (lr * grad[i]) / ( math::nd4j_sqrt<T, T>(st[i]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
|
||||||
|
bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
|
||||||
|
bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateG.getShapeInfo());
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
|
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
|
||||||
|
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
|
||||||
|
const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
|
||||||
|
const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateG.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
st[stOffset] = init[initOffset] * rmsDecay + grad[xOffset] * grad[xOffset] * (1 - rmsDecay) ;
|
||||||
|
up[zOffset] = (lr * grad[xOffset]) / ( math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG,
|
||||||
|
const double dLr, const double dRmsDecay, const double dEpsilon) {
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdater_, (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,129 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void adaDeltaUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, const Nd4jLong* inMsgShapeInfo,
|
||||||
|
const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstMsg,
|
||||||
|
const Nd4jLong* stMsgShapeInfo, void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto initMsg= reinterpret_cast<const T*>(vinMsg);
|
||||||
|
const auto initMsdx = reinterpret_cast<const T*>(vinMsdx);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto stMsg = reinterpret_cast<T*>(vstMsg);
|
||||||
|
auto stMsdx = reinterpret_cast<T*>(vstMsdx);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T rhoT;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, bXInMsdxSame, bXStMsdxSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
rhoT = (1 - rho);
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stMsgShapeInfo) && 1 == shape::elementWiseStride(inMsgShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stMsdxShapeInfo) && 1 == shape::elementWiseStride(inMsdxShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) &&
|
||||||
|
shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) &&
|
||||||
|
shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo);
|
||||||
|
bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo);
|
||||||
|
bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo);
|
||||||
|
bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, stMsgOffset = i, stMsdxOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering){
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords);
|
||||||
|
stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords);
|
||||||
|
initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords);
|
||||||
|
stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT;
|
||||||
|
|
||||||
|
up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt<T, T>(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt<T, T>(stMsg[stMsgOffset] + epsilon));
|
||||||
|
|
||||||
|
stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void adaDeltaUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, const Nd4jLong* stMsgShapeInfo,
|
||||||
|
void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) {
|
||||||
|
|
||||||
|
const T rho = static_cast<T>(dRho);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
|
||||||
|
adaDeltaUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vinMsg, inMsgShapeInfo,
|
||||||
|
vinMsdx, inMsdxShapeInfo, vz, zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx,
|
||||||
|
NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "adaDeltaUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx });
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initStateMsg.getSpecialBuffer(), initStateMsg.getSpecialShapeInfo(), initStateMsdx.getSpecialBuffer(), initStateMsdx.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(),stateMsg.getSpecialBuffer(), stateMsg.getSpecialShapeInfo(),
|
||||||
|
stateMsdx.getSpecialBuffer(), stateMsdx.getSpecialShapeInfo(), dRho, dEpsilon), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,117 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
|
||||||
|
const T lr, const T epsilon) {
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto init = reinterpret_cast<const T*>(vin);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto st = reinterpret_cast<T*>(vst);
|
||||||
|
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) &&
|
||||||
|
shape::order(xShapeInfo) == shape::order(inShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
|
||||||
|
bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering) {
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
|
||||||
|
stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset];
|
||||||
|
up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void adaGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
|
||||||
|
const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
|
||||||
|
const double dLr, const double dEpsilon) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
|
||||||
|
adaGradUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vin, inShapeInfo,
|
||||||
|
vz, zShapeInfo, vst, stShapeInfo, lr, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState,
|
||||||
|
NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "adaGradUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateH }, { &gradient, &initState });
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(),
|
||||||
|
gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initState.getSpecialBuffer(), initState.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(),
|
||||||
|
stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dEpsilon), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateH }, { &gradient, &initState });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,142 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void adaMaxUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
|
||||||
|
const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
|
||||||
|
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto initU = reinterpret_cast<const T*>(vinv);
|
||||||
|
const auto initM = reinterpret_cast<const T*>(vinm);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto stU = reinterpret_cast<T*>(vstV);
|
||||||
|
auto stM = reinterpret_cast<T*>(vstM);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T beta1T, epsilonT;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
beta1T = sd::math::nd4j_pow<T,T,T>(beta1, (iteration + 1) );
|
||||||
|
|
||||||
|
epsilonT = lr / (1.0 - beta1T);
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stmShapeInfo) &&
|
||||||
|
shape::order(xShapeInfo) == shape::order(inmShapeInfo) && shape::order(xShapeInfo) == shape::order(invShapeInfo) &&
|
||||||
|
shape::order(xShapeInfo) == shape::order(stvShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
|
||||||
|
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
|
||||||
|
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
|
||||||
|
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering) {
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
|
||||||
|
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
|
||||||
|
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
|
||||||
|
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
//m = B_1 * m + (1-B_1)*grad
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||||
|
//u = max(B_2 * u, |grad|)
|
||||||
|
stU[stUOffset] = sd::math::nd4j_max( (beta2* initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32;
|
||||||
|
|
||||||
|
up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void adaMaxUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
|
||||||
|
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr,
|
||||||
|
const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
adaMaxUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz,
|
||||||
|
zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
|
||||||
|
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1,
|
||||||
|
const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "adaMaxUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(),
|
||||||
|
gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), initStateU.getSpecialBuffer(),
|
||||||
|
initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(),
|
||||||
|
stateU.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(),
|
||||||
|
dLr, dBeta1, dBeta2, dEpsilon, nIteration ), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,139 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm,
|
||||||
|
const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV,
|
||||||
|
const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
|
||||||
|
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto initU = reinterpret_cast<const T*>(vinv);
|
||||||
|
const auto initM = reinterpret_cast<const T*>(vinm);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto stU = reinterpret_cast<T*>(vstV);
|
||||||
|
auto stM = reinterpret_cast<T*>(vstM);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T epsilonT;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
|
||||||
|
|
||||||
|
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
|
||||||
|
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
|
||||||
|
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
|
||||||
|
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
|
||||||
|
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
|
||||||
|
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering){
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
|
||||||
|
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
|
||||||
|
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
|
||||||
|
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
|
||||||
|
stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2);
|
||||||
|
|
||||||
|
up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
|
||||||
|
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
adamUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
|
||||||
|
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
|
||||||
|
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
|
||||||
|
const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "adamUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initStateU.getSpecialBuffer(), initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), stateU.getSpecialShapeInfo(),
|
||||||
|
stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void amsGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
|
||||||
|
const void* vinm, const Nd4jLong* inmShapeInfo, const void* vinh, const Nd4jLong* inhShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM,
|
||||||
|
const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo,
|
||||||
|
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto initV = reinterpret_cast<const T*>(vinv);
|
||||||
|
const auto initM = reinterpret_cast<const T*>(vinm);
|
||||||
|
const auto initH = reinterpret_cast<const T*>(vinh);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto stV = reinterpret_cast<T*>(vstV);
|
||||||
|
auto stM = reinterpret_cast<T*>(vstM);
|
||||||
|
auto stH = reinterpret_cast<T*>(vstH);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T mbeta1, mbeta2, epsilonT;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame, bXInHSame, bXStHSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1.0 - sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1)));
|
||||||
|
|
||||||
|
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
|
||||||
|
epsilonT = epsilon;
|
||||||
|
|
||||||
|
mbeta1 = (1 - beta1);
|
||||||
|
mbeta2 = (1 - beta2);
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(sthShapeInfo) && 1 == shape::elementWiseStride(inhShapeInfo);
|
||||||
|
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
|
||||||
|
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
|
||||||
|
shape::order(stvShapeInfo) == shape::order(invShapeInfo) && shape::order(invShapeInfo) == shape::order(sthShapeInfo) &&
|
||||||
|
shape::order(sthShapeInfo) == shape::order(inhShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
|
||||||
|
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
|
||||||
|
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
|
||||||
|
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
|
||||||
|
bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo);
|
||||||
|
bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering){
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
|
||||||
|
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
|
||||||
|
initVOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
|
||||||
|
stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
|
||||||
|
initHOffset = bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords);
|
||||||
|
stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1;
|
||||||
|
stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
|
||||||
|
stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]);
|
||||||
|
|
||||||
|
up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt<T, T>(stH[stHOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void amsGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
|
||||||
|
const void* vinh, const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
|
||||||
|
void* vstH, const Nd4jLong* sthShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
amsGradUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
|
||||||
|
vinh, inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH,
|
||||||
|
NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "amsGradUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateV, &stateM, &stateH }, { &gradient, &initStateV, &initStateM, &initStateH });
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
|
||||||
|
initStateH.getSpecialBuffer(), initStateH.getSpecialShapeInfo(), update.getSpecialBuffer(), update.getSpecialShapeInfo(),
|
||||||
|
stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(),
|
||||||
|
stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateV, &stateM , &stateH }, { &gradient, &initStateV, &initStateM, &initStateH });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,137 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
|
||||||
|
const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
|
||||||
|
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto initV = reinterpret_cast<const T*>(vinv);
|
||||||
|
const auto initM = reinterpret_cast<const T*>(vinm);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto stV = reinterpret_cast<T*>(vstV);
|
||||||
|
auto stM = reinterpret_cast<T*>(vstM);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T mbeta1T, mbeta1, mbeta2;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
mbeta1T = 1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
|
||||||
|
mbeta1 = (1 - beta1);
|
||||||
|
mbeta2 = (1 - beta2);
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
|
||||||
|
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
|
||||||
|
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
|
||||||
|
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
|
||||||
|
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
|
||||||
|
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering){
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
|
||||||
|
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
|
||||||
|
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
|
||||||
|
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto oneMinusBeta1Grad = grad[xOffset] * mbeta1;
|
||||||
|
|
||||||
|
stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad;
|
||||||
|
stV[stUOffset] = beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
|
||||||
|
|
||||||
|
up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt<T, T>(stV[stUOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void nadamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM,
|
||||||
|
const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T beta1 = static_cast<T>(dBeta1);
|
||||||
|
const T beta2 = static_cast<T>(dBeta2);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
const T iteration = static_cast<T>(nIteration);
|
||||||
|
|
||||||
|
nadamUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
|
||||||
|
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM,
|
||||||
|
NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "nadamUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM });
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(),
|
||||||
|
stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,117 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) {
|
||||||
|
|
||||||
|
const auto grad = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto init = reinterpret_cast<const T*>(vin);
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto st = reinterpret_cast<T*>(vst);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ T momentumT;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
momentumT = (-momentum - 1);
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
|
||||||
|
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) &&
|
||||||
|
shape::order(xShapeInfo) == shape::order(stShapeInfo);
|
||||||
|
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
|
||||||
|
bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering) {
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
|
||||||
|
stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
T prevState = momentum * init[initOffset];
|
||||||
|
st[stOffset] = prevState - lr * grad[xOffset];
|
||||||
|
up[zOffset] = prevState + momentumT * st[stOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
|
||||||
|
const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
|
||||||
|
const double dLr, const double dMomentum) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T momentum = static_cast<T>(dMomentum);
|
||||||
|
nesterovsUpdaterCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, vin, inShapeInfo,
|
||||||
|
vz, zShapeInfo, vst, stShapeInfo, lr, momentum);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState,
|
||||||
|
NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "nesterovsUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState });
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock,
|
||||||
|
context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initState.getSpecialBuffer(), initState.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(),
|
||||||
|
stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), dLr, dMomentum), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, &initState });
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/updatersHelpers.h>
|
||||||
|
#include <helpers/PointersManager.h>
|
||||||
|
#include <math/platformmath.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
|
||||||
|
const T lr, const T rmsDecay, const T epsilon) {
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto init = reinterpret_cast<const T*>(vin);
|
||||||
|
|
||||||
|
auto up = reinterpret_cast<T*>(vz);
|
||||||
|
auto st = reinterpret_cast<T*>(vst);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLen;
|
||||||
|
__shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
|
||||||
|
1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
|
||||||
|
|
||||||
|
bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) &&
|
||||||
|
shape::order(xShapeInfo) == shape::order(inShapeInfo);
|
||||||
|
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
|
||||||
|
bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int coords[MAX_RANK];
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;
|
||||||
|
|
||||||
|
if (!bEWS || !bOrdering) {
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, coords);
|
||||||
|
xOffset = shape::getOffset(xShapeInfo, coords);
|
||||||
|
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
|
||||||
|
initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
|
||||||
|
stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
|
||||||
|
}
|
||||||
|
|
||||||
|
st[stOffset] = init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay) ;
|
||||||
|
up[zOffset] = (lr * x[xOffset]) / ( math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
linkage void rmsPropUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||||
|
const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
|
||||||
|
const double dLr, const double dRmsDecay, const double dEpsilon) {
|
||||||
|
|
||||||
|
const T lr = static_cast<T>(dLr);
|
||||||
|
const T rmsDecay = static_cast<T>(dRmsDecay);
|
||||||
|
const T epsilon = static_cast<T>(dEpsilon);
|
||||||
|
|
||||||
|
rmsPropUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vin, inShapeInfo,
|
||||||
|
vz, zShapeInfo, vst, stShapeInfo, lr, rmsDecay, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG,
|
||||||
|
const double dLr, const double dRmsDecay, const double dEpsilon) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "rmsPropUpdater");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState });
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock,
|
||||||
|
context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
|
||||||
|
initState.getSpecialBuffer(), initState.getSpecialShapeInfo(),
|
||||||
|
update.getSpecialBuffer(), update.getSpecialShapeInfo(),
|
||||||
|
stateG.getSpecialBuffer(), stateG.getSpecialShapeInfo(),
|
||||||
|
dLr, dRmsDecay, dEpsilon ), FLOAT_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019-2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LIBND4J_UPDATER_RMS_PROM_H
|
||||||
|
#define LIBND4J_UPDATER_RMS_PROM_H
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, const double dLr, const double dRmsDecay, const double dEpsilon);
|
||||||
|
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon);
|
||||||
|
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double bMomentum);
|
||||||
|
void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
|
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
|
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
|
||||||
|
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
|
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
File diff suppressed because it is too large
Load Diff
|
@ -43,6 +43,15 @@ public class ImportClassMapping {
|
||||||
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
|
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
|
||||||
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
|
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
|
||||||
org.nd4j.linalg.api.ops.NoOp.class,
|
org.nd4j.linalg.api.ops.NoOp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
|
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
|
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class AdaDeltaUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AdaDeltaUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, double rho, double epsilon) {
|
||||||
|
this(gradients, stateMsg, stateMsdx, gradients, stateMsg, stateMsdx, rho, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, @NonNull INDArray updates, @NonNull INDArray updatedStateMsg, @NonNull INDArray updatedStateMsdx, double rho, double epsilon) {
|
||||||
|
addInputArgument(gradients, stateMsg, stateMsdx);
|
||||||
|
addOutputArgument(updates, updatedStateMsg, updatedStateMsdx);
|
||||||
|
addTArgument(rho, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "ada_delta_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class AdaGradUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AdaGradUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double epsilon) {
|
||||||
|
this(gradients, state, gradients, state, lr, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double epsilon) {
|
||||||
|
addInputArgument(gradients, state);
|
||||||
|
addOutputArgument(updates, updatedState);
|
||||||
|
addTArgument(lr, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "ada_grad_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class AdaMaxUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AdaMaxUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
addInputArgument(gradients, stateU, stateM);
|
||||||
|
addOutputArgument(updates, updatedStateU, updatedStateM);
|
||||||
|
addTArgument(lr, beta1, beta2, epsilon);
|
||||||
|
addIArgument(iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "ada_max_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class AdamUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AdamUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
addInputArgument(gradients, stateU, stateM);
|
||||||
|
addOutputArgument(updates, updatedStateU, updatedStateM);
|
||||||
|
addTArgument(lr, beta1, beta2, epsilon);
|
||||||
|
addIArgument(iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "adam_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class AmsGradUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public AmsGradUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
addInputArgument(gradients, stateV, stateM, stateH);
|
||||||
|
addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH);
|
||||||
|
addTArgument(lr, beta1, beta2, epsilon);
|
||||||
|
addIArgument(iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "ams_grad_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class NadamUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public NadamUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
|
||||||
|
addInputArgument(gradients, stateV, stateM);
|
||||||
|
addOutputArgument(updates, updatedStateV, updatedStateM);
|
||||||
|
addTArgument(lr, beta1, beta2, epsilon);
|
||||||
|
addIArgument(iteration);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "nadam_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class NesterovsUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public NesterovsUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double momentum) {
|
||||||
|
this(gradients, state, gradients, state, lr, momentum);
|
||||||
|
}
|
||||||
|
|
||||||
|
public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double momentum) {
|
||||||
|
addInputArgument(gradients, state);
|
||||||
|
addOutputArgument(updates, updatedState);
|
||||||
|
addTArgument(lr, momentum);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "nesterovs_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class RmsPropUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public RmsPropUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double decay, double epsilon) {
|
||||||
|
this(gradients, state, gradients, state, lr, decay, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double decay, double epsilon) {
|
||||||
|
addInputArgument(gradients, state);
|
||||||
|
addOutputArgument(updates, updatedState);
|
||||||
|
addTArgument(lr, decay, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "rms_prop_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.updaters;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author raver119@gmail.com
|
||||||
|
*/
|
||||||
|
public class SgdUpdater extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public SgdUpdater() {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
|
||||||
|
public SgdUpdater(@NonNull INDArray input, double lr) {
|
||||||
|
this(input, input, lr);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SgdUpdater(@NonNull INDArray input, @NonNull INDArray output, double lr) {
|
||||||
|
addInputArgument(input);
|
||||||
|
addOutputArgument(output);
|
||||||
|
addTArgument(lr);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "sgd_updater";
|
||||||
|
}
|
||||||
|
}
|
|
@ -10686,6 +10686,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
// #include <ops/declarable/headers/util.h>
|
// #include <ops/declarable/headers/util.h>
|
||||||
// #include <ops/declarable/headers/BarnesHutTsne.h>
|
// #include <ops/declarable/headers/BarnesHutTsne.h>
|
||||||
// #include <ops/declarable/headers/images.h>
|
// #include <ops/declarable/headers/images.h>
|
||||||
|
// #include <ops/declarable/headers/updaters.h>
|
||||||
// #include <system/dll.h>
|
// #include <system/dll.h>
|
||||||
// #include <helpers/shape.h>
|
// #include <helpers/shape.h>
|
||||||
// #include <helpers/TAD.h>
|
// #include <helpers/TAD.h>
|
||||||
|
|
|
@ -12422,6 +12422,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
// #include <ops/declarable/headers/util.h>
|
// #include <ops/declarable/headers/util.h>
|
||||||
// #include <ops/declarable/headers/BarnesHutTsne.h>
|
// #include <ops/declarable/headers/BarnesHutTsne.h>
|
||||||
// #include <ops/declarable/headers/images.h>
|
// #include <ops/declarable/headers/images.h>
|
||||||
|
// #include <ops/declarable/headers/updaters.h>
|
||||||
// #include <system/dll.h>
|
// #include <system/dll.h>
|
||||||
// #include <helpers/shape.h>
|
// #include <helpers/shape.h>
|
||||||
// #include <helpers/TAD.h>
|
// #include <helpers/TAD.h>
|
||||||
|
|
|
@ -15,10 +15,12 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.nd4j.linalg.learning;
|
package org.nd4j.linalg.learning;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.learning.config.*;
|
import org.nd4j.linalg.learning.config.*;
|
||||||
|
@ -58,14 +60,23 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val msgu = msg.dup();
|
||||||
|
val msdxu = msdx.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon);
|
UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(g3, msgu, msdxu, rho, epsilon));
|
||||||
|
|
||||||
assertEquals(msg, state.get("msg"));
|
assertEquals(msg, state.get("msg"));
|
||||||
assertEquals(msdx, state.get("msdx"));
|
assertEquals(msdx, state.get("msdx"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(msg, msgu);
|
||||||
|
assertEquals(msdx, msdxu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,13 +96,20 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val su = s.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
|
UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(g3, su, lr, epsilon));
|
||||||
|
|
||||||
assertEquals(s, state.get("grad"));
|
assertEquals(s, state.get("grad"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(s, su);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,14 +136,23 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val mu = m.dup();
|
||||||
|
val vu = v.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
|
UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i));
|
||||||
|
|
||||||
assertEquals(m, state.get("M"));
|
assertEquals(m, state.get("M"));
|
||||||
assertEquals(v, state.get("V"));
|
assertEquals(v, state.get("V"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(m, mu);
|
||||||
|
assertEquals(v, vu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,14 +177,23 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val mu = m.dup();
|
||||||
|
val vu = v.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i);
|
UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(g3, vu, mu, lr, beta1, beta2, eps, i));
|
||||||
|
|
||||||
assertEquals(m, state.get("M"));
|
assertEquals(m, state.get("M"));
|
||||||
assertEquals(v, state.get("V"));
|
assertEquals(v, state.get("V"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(m, mu);
|
||||||
|
assertEquals(v, vu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,15 +221,26 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val mu = m.dup();
|
||||||
|
val vu = v.dup();
|
||||||
|
val hu = vH.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i);
|
UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
|
||||||
|
Nd4j.exec(new AmsGradUpdater(g3, vu, mu, hu, lr, beta1, beta2, eps, i));
|
||||||
|
|
||||||
assertEquals(m, state.get("M"));
|
assertEquals(m, state.get("M"));
|
||||||
assertEquals(v, state.get("V"));
|
assertEquals(v, state.get("V"));
|
||||||
assertEquals(vH, state.get("V_HAT"));
|
assertEquals(vH, state.get("V_HAT"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(m, mu);
|
||||||
|
assertEquals(v, vu);
|
||||||
|
assertEquals(vH, hu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,14 +266,23 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val vu = v.dup();
|
||||||
|
val mu = m.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
|
UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i));
|
||||||
|
|
||||||
assertEquals(m, state.get("M"));
|
assertEquals(m, state.get("M"));
|
||||||
assertEquals(v, state.get("V"));
|
assertEquals(v, state.get("V"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(m, mu);
|
||||||
|
assertEquals(v, vu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,13 +303,18 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val vu = v.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum);
|
UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(g3, vu, lr, momentum));
|
||||||
|
|
||||||
assertEquals(v, state.get("V"));
|
assertEquals(v, state.get("V"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(v, vu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,13 +336,19 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
val gu = g.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps);
|
UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps);
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(g3, gu, lr,decay, eps));
|
||||||
|
|
||||||
assertEquals(g, state.get("G"));
|
assertEquals(g, state.get("G"));
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
|
||||||
|
assertEquals(g, gu);
|
||||||
|
assertEquals(g1, g3);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,11 +361,14 @@ public class UpdaterValidation extends BaseNd4jTest {
|
||||||
for( int i=0; i<3; i++ ) {
|
for( int i=0; i<3; i++ ) {
|
||||||
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
|
||||||
INDArray g2 = g1.dup();
|
INDArray g2 = g1.dup();
|
||||||
|
val g3 = g1.dup();
|
||||||
|
|
||||||
UpdaterJavaCode.applySgd(g1, lr);
|
UpdaterJavaCode.applySgd(g1, lr);
|
||||||
|
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(g3, lr));
|
||||||
|
|
||||||
u.applyUpdater(g2, i, 0);
|
u.applyUpdater(g2, i, 0);
|
||||||
assertEquals(g1, g2);
|
assertEquals(g1, g2);
|
||||||
|
assertEquals(g1, g3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue