Learning updaters for gradient (#335)

* libnd4j raw implementation of sgd updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections and simple test added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections after discussion

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j integrate applyScalar

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j raw implementation of rmsPropUpdater on cpu

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fix operations declaration

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j rmsPropUpdater added, test cases for sgd, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed several typos

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some fixes and improvements for rmsPropUpdater based on Java tests

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed cuda implementation, updated tests and corrected behavior according to Java tests

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j adaGrad updater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one minor fix for ada grad

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several more fixes for ada_grad

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j nesterovs updater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed nesterovs updater behavior and several typos, renamed file

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one minor typo

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j ada max updater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed several typos in adaMax updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed several typos in adaMaxUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several fixes for adaMax, added Adam Updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j adaDeltaUpdater added, minor fixes for adamUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several fixes for adaDeltaUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j nadamUpdater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one more correction for nadam updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several fixes for nadam updater and added amsGradUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several typos fixed in amsGradUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections and added f order support for rmsProp updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added f order support for all updaters and modified tests to exercise in-place execution

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed issues for updaters when in-place mode is not used, added tests for f order

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added input shape checks

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections for handling different cases

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some code cleanup and optimization per request

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j updaters refactoring after review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* SgdUpdater wrapper

Signed-off-by: raver119 <raver119@gmail.com>

* first test

Signed-off-by: raver119 <raver119@gmail.com>

* RmsPropUpdater added

Signed-off-by: raver119 <raver119@gmail.com>

* NadamUpdater + NesterovsUpdater

Signed-off-by: raver119 <raver119@gmail.com>

* AmsGradUpdater

Signed-off-by: raver119 <raver119@gmail.com>

* AdamUpdater added

Signed-off-by: raver119 <raver119@gmail.com>

* AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater

Signed-off-by: raver119 <raver119@gmail.com>

* AdaGradUpdater test added

Signed-off-by: raver119 <raver119@gmail.com>

* libnd4j removed input parameter parsing through NDArray, split helper implementations into separate files, some renames, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j next step to split operations implementation into separate files

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge master and minor corrections

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j revert some changes of split implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j forgot to add header file

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* public default constructors

Signed-off-by: raver119 <raver119@gmail.com>

* ImportClassMapping updated

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master · commit 69c92ca5ae (parent 015147b713)
Oleh, 2020-03-23 06:28:31 +02:00, committed by GitHub
42 changed files with 4646 additions and 2 deletions

@@ -45,6 +45,7 @@
#include <ops/declarable/headers/util.h>
#include <ops/declarable/headers/BarnesHutTsne.h>
#include <ops/declarable/headers/images.h>
#include <ops/declarable/headers/updaters.h>
#include <system/dll.h>
#include <helpers/shape.h>
#include <helpers/TAD.h>

@@ -0,0 +1,81 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initStateMsg = INPUT_VARIABLE(1);
const auto initStateMsdx = INPUT_VARIABLE(2);
auto update = OUTPUT_VARIABLE(0);
auto stateMsg = OUTPUT_VARIABLE(1);
auto stateMsdx = OUTPUT_VARIABLE(2);
if (gradient->isEmpty() || initStateMsg->isEmpty() || initStateMsdx->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, "ADA_DELTA UPDATER OP: input state Msg must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateMsg->getShapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, "ADA_DELTA UPDATER OP: input state Msdx must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateMsdx->getShapeInfo()).c_str());
bool bParamsSupply = 5 == block.width() || 2 == block.getTArguments()->size();
REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!");
double dRho, dEpsilon;
if (block.width() > 3) {
const auto rho = INPUT_VARIABLE(3);
const auto epsilon = INPUT_VARIABLE(4);
REQUIRE_TRUE(rho->isScalar(), 0, "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead got rank %i!", rho->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dRho = rho->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dRho = T_ARG(0);
dEpsilon = T_ARG(1);
}
helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, dEpsilon);
return Status::OK();
}
DECLARE_TYPES(ada_delta_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

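A minimal invocation sketch of the op above (not part of the diff; it assumes the DeclarableOp::evaluate({inputs}, {T args}, {I args}) helper used by this repo's C++ test suites — the function name and values are illustrative):

#include <ops/declarable/CustomOperations.h>
#include <array/NDArrayFactory.h>

static void adaDeltaUpdaterSketch() {
    // gradient plus the two running states; all three must share one shape (see the REQUIRE_TRUE checks above)
    auto grad = sd::NDArrayFactory::create<float>('c', {1, 4}, {0.5f, 0.5f, 0.5f, 0.5f});
    auto msg  = sd::NDArrayFactory::create<float>('c', {1, 4}, {0.f, 0.f, 0.f, 0.f}); // running average of squared gradients
    auto msdx = sd::NDArrayFactory::create<float>('c', {1, 4}, {0.f, 0.f, 0.f, 0.f}); // running average of squared updates

    sd::ops::ada_delta_updater op;
    // rho and epsilon supplied as T args, i.e. the block.getTArguments() branch above
    auto results = op.evaluate({&grad, &msg, &msdx}, {0.95, 1e-6}, {});
    // outputs, in order: update, new Msg state, new Msdx state
}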
@@ -0,0 +1,77 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initState = INPUT_VARIABLE(1);
auto update = OUTPUT_VARIABLE(0);
auto stateH = OUTPUT_VARIABLE(1);
if (gradient->isEmpty() || initState->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initState), 0, "ADA_GRAD UPDATER OP: input state must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str());
bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size();
REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!");
double dLr, dEpsilon;
if (block.width() > 2) {
const auto lr = INPUT_VARIABLE(2);
const auto epsilon = INPUT_VARIABLE(3);
REQUIRE_TRUE(lr->isScalar(), 0, "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dEpsilon = T_ARG(1);
}
helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, *stateH, dLr, dEpsilon);
return Status::OK();
}
DECLARE_TYPES(ada_grad_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

@@ -0,0 +1,93 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initStateU = INPUT_VARIABLE(1);
const auto initStateM = INPUT_VARIABLE(2);
auto update = OUTPUT_VARIABLE(0);
auto stateU = OUTPUT_VARIABLE(1);
auto stateM = OUTPUT_VARIABLE(2);
// todo maybe we need an error here, like on the Java side
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADA_MAX UPDATER OP: input state V must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADA_MAX UPDATER OP: input state M must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
int iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
double dLr, dBeta1, dBeta2, dEpsilon;
if (block.width() > 3) {
const auto lr = INPUT_VARIABLE(3);
const auto beta1 = INPUT_VARIABLE(4);
const auto beta2 = INPUT_VARIABLE(5);
const auto epsilon = INPUT_VARIABLE(6);
REQUIRE_TRUE(lr->isScalar(), 0, "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(beta1->isScalar(), 0, "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
REQUIRE_TRUE(beta2->isScalar(), 0, "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dBeta1 = beta1->e<double>(0);
dBeta2 = beta2->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dBeta1 = T_ARG(1);
dBeta2 = T_ARG(2);
dEpsilon = T_ARG(3);
}
helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
return Status::OK();
}
DECLARE_TYPES(ada_max_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

@@ -0,0 +1,92 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initStateU = INPUT_VARIABLE(1);
const auto initStateM = INPUT_VARIABLE(2);
auto update = OUTPUT_VARIABLE(0);
auto stateU = OUTPUT_VARIABLE(1);
auto stateM = OUTPUT_VARIABLE(2);
// todo maybe we need an error here, like on the Java side
if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADAM UPDATER OP: input state V must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADAM UPDATER OP: input state M must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
double dLr, dBeta1, dBeta2, dEpsilon;
if (block.width() > 3) {
const auto lr = INPUT_VARIABLE(3);
const auto beta1 = INPUT_VARIABLE(4);
const auto beta2 = INPUT_VARIABLE(5);
const auto epsilon = INPUT_VARIABLE(6);
REQUIRE_TRUE(lr->isScalar(), 0, "ADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(beta1->isScalar(), 0, "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
REQUIRE_TRUE(beta2->isScalar(), 0, "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dBeta1 = beta1->e<double>(0);
dBeta2 = beta2->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dBeta1 = T_ARG(1);
dBeta2 = T_ARG(2);
dEpsilon = T_ARG(3);
}
helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration);
return Status::OK();
}
DECLARE_TYPES(adam_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

@@ -0,0 +1,98 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initStateV = INPUT_VARIABLE(1);
const auto initStateM = INPUT_VARIABLE(2);
const auto initStateH = INPUT_VARIABLE(3);
auto update = OUTPUT_VARIABLE(0);
auto stateV = OUTPUT_VARIABLE(1);
auto stateM = OUTPUT_VARIABLE(2);
auto stateH = OUTPUT_VARIABLE(3);
// todo maybe we need an error here, like on the Java side
if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || initStateH->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "AMSGRAD UPDATER OP: input state V must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "AMSGRAD UPDATER OP: input state M must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, "AMSGRAD UPDATER OP: input state H must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateH->getShapeInfo()).c_str());
bool bParamsSupply = 8 == block.width() || 4 == block.getTArguments()->size();
auto iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
double dLr, dBeta1, dBeta2, dEpsilon;
if (block.width() > 4) {
const auto lr = INPUT_VARIABLE(4);
const auto beta1 = INPUT_VARIABLE(5);
const auto beta2 = INPUT_VARIABLE(6);
const auto epsilon = INPUT_VARIABLE(7);
REQUIRE_TRUE(lr->isScalar(), 0, "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(beta1->isScalar(), 0, "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
REQUIRE_TRUE(beta2->isScalar(), 0, "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dBeta1 = beta1->e<double>(0);
dBeta2 = beta2->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dBeta1 = T_ARG(1);
dBeta2 = T_ARG(2);
dEpsilon = T_ARG(3);
}
helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, *initStateM, *initStateH,
*update, *stateV, *stateM, *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration);
return Status::OK();
}
DECLARE_TYPES(ams_grad_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

@@ -0,0 +1,92 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initStateV = INPUT_VARIABLE(1);
const auto initStateM = INPUT_VARIABLE(2);
auto update = OUTPUT_VARIABLE(0);
auto stateV = OUTPUT_VARIABLE(1);
auto stateM = OUTPUT_VARIABLE(2);
// todo maybe we need an error here, like on the Java side
if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "NADAM UPDATER OP: input state M must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str());
REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "NADAM UPDATER OP: input state V must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str());
bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size();
auto nIteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!");
double dLr, dBeta1, dBeta2, dEpsilon;
if (block.width() > 3) {
const auto lr = INPUT_VARIABLE(3);
const auto beta1 = INPUT_VARIABLE(4);
const auto beta2 = INPUT_VARIABLE(5);
const auto epsilon = INPUT_VARIABLE(6);
REQUIRE_TRUE(lr->isScalar(), 0, "NADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(beta1->isScalar(), 0, "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf());
REQUIRE_TRUE(beta2->isScalar(), 0, "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "NADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dBeta1 = beta1->e<double>(0);
dBeta2 = beta2->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dBeta1 = T_ARG(1);
dBeta2 = T_ARG(2);
dEpsilon = T_ARG(3);
}
helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, *initStateM, *update, *stateV, *stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration);
return Status::OK();
}
DECLARE_TYPES(nadam_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

@@ -0,0 +1,75 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initState = INPUT_VARIABLE(1);
auto update = OUTPUT_VARIABLE(0);
auto stateV = OUTPUT_VARIABLE(1);
if (gradient->isEmpty() || initState->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initState), 0, "NESTEROVS UPDATER OP: input state Msg must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str());
bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size();
REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!");
double dLr, dMomentum;
if (block.width() > 2) {
const auto lr = INPUT_VARIABLE(2);
const auto momentum = INPUT_VARIABLE(3);
REQUIRE_TRUE(lr->isScalar(), 0, "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(momentum->isScalar(), 0, "NESTEROVS UPDATER OP: Momentum has to be a scalar, but instead got rank %i!", momentum->rankOf());
dLr = lr->e<double>(0);
dMomentum = momentum->e<double>(0);
}
else {
dLr = T_ARG(0);
dMomentum = T_ARG(1);
}
helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, *update, *stateV, dLr, dMomentum);
return Status::OK();
}
DECLARE_TYPES(nesterovs_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

@@ -0,0 +1,80 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) {
const auto gradient = INPUT_VARIABLE(0);
const auto initState = INPUT_VARIABLE(1);
auto update = OUTPUT_VARIABLE(0);
auto stateG = OUTPUT_VARIABLE(1);
if (gradient->isEmpty() || initState->isEmpty())
return Status::OK();
REQUIRE_TRUE(gradient->isSameShape(initState), 0, "RMS_PROB UPDATER OP: input state must have the same shape as gradient,"
" expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(),
ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str());
bool bParamsSupply = 5 == block.width() || 3 == block.getTArguments()->size();
REQUIRE_TRUE(bParamsSupply, 0, "RSM_PROB UPDATER OP: learning rate, rsm decay and epsilon were not provided!");
double dLr, dRmsDecay, dEpsilon;
if (block.width() > 2) {
const auto lr = INPUT_VARIABLE(2);
const auto rmsDecay = INPUT_VARIABLE(3);
const auto epsilon = INPUT_VARIABLE(4);
REQUIRE_TRUE(lr->isScalar(), 0, "RSM_PROB UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
REQUIRE_TRUE(rmsDecay->isScalar(), 0, "RSM_PROB UPDATER OP: Rms decay has to be a scalar, but instead got rank %i!", rmsDecay->rankOf());
REQUIRE_TRUE(epsilon->isScalar(), 0, "RSM_PROB UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf());
dLr = lr->e<double>(0);
dRmsDecay = rmsDecay->e<double>(0);
dEpsilon = epsilon->e<double>(0);
}
else {
dLr = T_ARG(0);
dRmsDecay = T_ARG(1);
dEpsilon = T_ARG(2);
}
helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, *stateG, dLr, dRmsDecay, dEpsilon);
return Status::OK();
}
DECLARE_TYPES(rms_prop_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

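In update-rule form the op computes (a sketch assuming it mirrors DL4J's RmsProp semantics; the CPU/CUDA helper lives in a separate file of this changeset):

$$ s_t = d\,s_{t-1} + (1-d)\,g_t^2, \qquad u_t = \frac{\eta\, g_t}{\sqrt{s_t} + \epsilon} $$

with learning rate $\eta$, rms decay $d$, gradient $g_t$, and state $s_t$.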
@@ -0,0 +1,61 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/updaters.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) {
const auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (input->isEmpty())
return Status::OK();
bool bLearningRate = 2 == block.width() || 1 == block.getTArguments()->size();
REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!");
if (block.width() > 1) {
const auto lr = INPUT_VARIABLE(1);
REQUIRE_TRUE(lr->isScalar(), 0, "SGD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf());
input->applyScalarArr(scalar::Multiply, *lr, *output);
}
else {
input->applyScalar(scalar::Multiply, T_ARG(0), *output);
}
return Status::OK();
}
DECLARE_TYPES(sgd_updater) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

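In effect this op computes update = lr * gradient. A minimal sketch of both parameter-passing forms (same assumptions as the earlier sketch: the evaluate() test helper, illustrative values):

#include <ops/declarable/CustomOperations.h>
#include <array/NDArrayFactory.h>

static void sgdUpdaterSketch() {
    auto grad = sd::NDArrayFactory::create<float>('c', {1, 3}, {1.f, 2.f, 3.f});
    sd::ops::sgd_updater op;
    // form 1: learning rate as T_ARG(0)
    auto r1 = op.evaluate({&grad}, {0.1}, {});
    // form 2: learning rate as a scalar input array, INPUT_VARIABLE(1)
    auto lr = sd::NDArrayFactory::create<float>(0.1f);
    auto r2 = op.evaluate({&grad, &lr}, {}, {});
}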
@@ -0,0 +1,210 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#ifndef LIBND4J_HEADERS_UPDATERS_H
#define LIBND4J_HEADERS_UPDATERS_H
#include <ops/declarable/headers/common.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <ops/declarable/helpers/updatersHelpers.h>
namespace sd {
namespace ops {
/**
* SGD updater
* Input arrays:
* 0 - input array with gradients.
* Optional:
* 1 - scalar learning rate value
* Optional:
* T args
* 0 - scalar learning rate value
*/
#if NOT_EXCLUDED(OP_sgd_updater)
DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0);
#endif
/**
* RmsProp updater
* Input arrays:
* 0 - input array with gradients.
* 1 - Initial state
* Optional:
* 2 - scalar learning rate value
* 3 - scalar rms decay
* 4 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - scalar rms decay
* 2 - epsilon
*/
#if NOT_EXCLUDED(OP_rms_prop_updater)
DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0);
#endif
// AdaGrad
/* Input arrays :
* 0 - input array with gradients.
* 1 - historical grad state
* Optional :
* 2 - scalar learning rate value
* 3 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - epsilon
*/
#if NOT_EXCLUDED(OP_ada_grad_updater)
DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0);
#endif
// AdaMax
/* Input arrays :
* 0 - input array with gradients.
* 1 - gradient state V
* 2 - gradient state M
* Optional :
* 3 - scalar learning rate value
* 4 - beta 1 value
* 5 - beta 2 value
* 6 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - beta 1 value
* 2 - beta 2 value
* 3 - epsilon
* Optional:
* I args
* 0 - iteration
*/
#if NOT_EXCLUDED(OP_ada_max_updater)
DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0);
#endif
// Nesterov's momentum
/* Input arrays :
* 0 - input array with gradients.
* 1 - V grad state
* Optional :
* 2 - scalar learning rate value
* 3 - scalar momentum value
* Optional:
* T args
* 0 - learning rate value
* 1 - momentum value
*/
#if NOT_EXCLUDED(OP_nesterovs_updater)
DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0);
#endif
// Adam
/* Input arrays :
* 0 - input array with gradients.
* 1 - gradient state V
* 2 - gradient state M
* Optional :
* 3 - scalar learning rate value
* 4 - beta 1 value
* 5 - beta 2 value
* 6 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - beta 1 value
* 2 - beta 2 value
* 3 - epsilon
* Optional:
* I args
* 0 - iteration
*/
#if NOT_EXCLUDED(OP_adam_updater)
DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0);
#endif
// AdaDelta
/* Input arrays :
* 0 - input array with gradients.
* 1 - Msg state (average of squared gradients)
* 2 - Msdx state (average of squared updates)
* Optional :
* 3 - scalar rho value
* 4 - epsilon
* Optional:
* T args
* 0 - rho
* 1 - epsilon
*/
#if NOT_EXCLUDED(OP_ada_delta_updater)
DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0);
#endif
// Nadam
/* Input arrays :
* 0 - input array with gradients.
* 1 - gradient state V
* 2 - gradient state M
* Optional :
* 3 - scalar learning rate value
* 4 - beta 1 value
* 5 - beta 2 value
* 6 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - beta 1 value
* 2 - beta 2 value
* 3 - epsilon
* Optional:
* I args
* 0 - iteration
*/
#if NOT_EXCLUDED(OP_nadam_updater)
DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0);
#endif
// AmsGrad
/* Input arrays :
* 0 - input array with gradients.
* 1 - gradient state V - squared gradients
* 2 - gradient state M - moving average
* 3 - gradient state H - max
* Optional :
* 4 - scalar learning rate value
* 5 - beta 1 value
* 6 - beta 2 value
* 7 - epsilon
* Optional:
* T args
* 0 - scalar learning rate value
* 1 - beta 1 value
* 2 - beta 2 value
* 3 - epsilon
* Optional:
* I args
* 0 - iteration
*/
#if NOT_EXCLUDED(OP_ams_grad_updater)
DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0);
#endif
}
}
#endif

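As a reading aid (not part of the diff), the update rules these declarations correspond to, as implemented by the CPU helpers later in this changeset — gradient $g_t$, learning rate $\eta$, emitted update $u_t$, and $t = \text{iteration} + 1$:

$$\text{SGD:}\quad u_t = \eta\, g_t$$
$$\text{AdaGrad:}\quad s_t = s_{t-1} + g_t^2,\qquad u_t = \frac{\eta\, g_t}{\sqrt{s_t} + \epsilon}$$
$$\text{AdaDelta:}\quad E[g^2]_t = \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2,\quad u_t = g_t\,\frac{\sqrt{E[u^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}},\quad E[u^2]_t = \rho\,E[u^2]_{t-1} + (1-\rho)\,u_t^2$$
$$\text{Adam:}\quad m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t,\quad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,\quad u_t = \eta\,\frac{\sqrt{1-\beta_2^{t}}}{1-\beta_1^{t}}\cdot\frac{m_t}{\sqrt{v_t} + \epsilon}$$
$$\text{AdaMax:}\quad m_t \text{ as in Adam},\quad w_t = \max(\beta_2 w_{t-1},\,|g_t|),\qquad u_t = \frac{\eta}{1-\beta_1^{t}}\cdot\frac{m_t}{w_t}$$
$$\text{AMSGrad:}\quad m_t, v_t \text{ as in Adam},\quad \hat v_t = \max(\hat v_{t-1},\, v_t),\qquad u_t = \eta\,\frac{\sqrt{1-\beta_2^{t}}}{1-\beta_1^{t}}\cdot\frac{m_t}{\sqrt{\hat v_t} + \epsilon}$$

(Nesterov's momentum and Nadam follow their standard DL4J formulations; their helper files are not shown in this excerpt.)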
@@ -0,0 +1,108 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void adaDeltaUpdater_(const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx,
NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) {
const T* grad = gradient.bufferAsT<T>();
const T* initMsg = initStateMsg.bufferAsT<T>();
const T* initMsdx = initStateMsdx.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* stMsg = stateMsg.bufferAsT<T>();
T* stMsdx = stateMsdx.bufferAsT<T>();
const T rho = static_cast<T>(dRho);
const T epsilon = static_cast<T>(dEpsilon);
const T rhoT = (1 - rho);
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateMsg.ews() && 1 == initStateMsg.ews() && 1 == stateMsdx.ews() && 1 == initStateMsdx.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() &&
update.ordering() == stateMsdx.ordering() &&
stateMsdx.ordering() == initStateMsdx.ordering() &&
stateMsdx.ordering() == initStateMsg.ordering() && stateMsg.ordering() == initStateMsg.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT;
up[i] = grad[i] * (sd::math::nd4j_sqrt<T, T>(initMsdx[i] + epsilon) / sd::math::nd4j_sqrt<T, T>(stMsg[i] + epsilon));
stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT;
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsg.getShapeInfo());
bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsg.getShapeInfo());
bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsdx.getShapeInfo());
bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsdx.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(initStateMsg.getShapeInfo(), coords);
const auto stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stateMsg.getShapeInfo(), coords);
const auto initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(initStateMsdx.getShapeInfo(), coords);
const auto stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stateMsdx.getShapeInfo(), coords);
stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT;
up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt<T, T>(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt<T, T>(stMsg[stMsgOffset] + epsilon));
stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT;
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx,
NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, (gradient, initStateMsg, initStateMsdx, update, stateMsg, stateMsdx, dRho, dEpsilon), FLOAT_TYPES);
}
}
}
}

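A one-element sanity check of the hot loop above: with rho = 0.95, epsilon = 1e-6, grad = 0.5 and both states at zero, stMsg = 0.05 * 0.25 = 0.0125, up = 0.5 * sqrt(1e-6) / sqrt(0.0125 + 1e-6) ≈ 4.47e-3, and stMsdx = 0.05 * up^2 ≈ 1.0e-6.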
@@ -0,0 +1,91 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) {
const T* grad = gradient.bufferAsT<T>();
const T* init = initState.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* st = stateH.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T epsilon = static_cast<T>(dEpsilon);
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
st[i] = init[i] + grad[i] * grad[i];
up[i] = (lr * grad[i]) / (math::nd4j_sqrt<T, T>(st[i]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords);
st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset];
up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH,
const double dLr, const double dEpsilon) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES);
}
}
}
}

@@ -0,0 +1,113 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T* grad = gradient.bufferAsT<T>();
const T* initU = initStateU.bufferAsT<T>();
const T* initM = initStateM.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* stU = stateU.bufferAsT<T>();
T* stM = stateM.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
T epsilonT = lr / (1.0 - beta1T);
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() &&
update.ordering() == stateU.ordering() &&
stateU.ordering() == initStateU.ordering() &&
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
//m = B_1 * m + (1-B_1)*grad
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
//u = max(B_2 * u, |grad|)
stU[i] = sd::math::nd4j_max((beta2 * initU[i]), sd::math::nd4j_abs(grad[i])) + 1e-32;
up[i] = stM[i] * epsilonT / stU[i];
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo());
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo());
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords);
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.getShapeInfo(), coords);
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
//m = B_1 * m + (1-B_1)*grad
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
//u = max(B_2 * u, |grad|)
stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32;
up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset];
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
}
}
}
}

@@ -0,0 +1,113 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update,
NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
const double dEpsilon, const int nIteration) {
const T* grad = gradient.bufferAsT<T>();
const T* initU = initStateU.bufferAsT<T>();
const T* initM = initStateM.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* stU = stateU.bufferAsT<T>();
T* stM = stateM.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
const T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
const T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1. - beta2T) / (1.0 - beta1T);
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() &&
update.ordering() == stateU.ordering() &&
stateU.ordering() == initStateU.ordering() &&
stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1);
stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2);
up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[i]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo());
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo());
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords);
const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.getShapeInfo(), coords);
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2);
up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
}
}
}
}

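Equivalently, the epsilonT computed above folds Adam's bias correction into the step size,

$$\eta_t = \eta\,\frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}},\qquad t = \text{iteration} + 1,$$

so each element gets $u = \eta_t\, m_t / (\sqrt{v_t} + \epsilon)$; the NaN/Inf/zero guard falls back to plain epsilon when the correction degenerates (e.g. a beta of exactly 1).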
@@ -0,0 +1,126 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH,
NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T* grad = gradient.bufferAsT<T>();
const T* initV = initStateV.bufferAsT<T>();
const T* initM = initStateM.bufferAsT<T>();
const T* initH = initStateH.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* stV = stateV.bufferAsT<T>();
T* stM = stateM.bufferAsT<T>();
T* stH = stateH.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
T epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1.0 - sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1)));
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
const T mbeta1 = (1 - beta1);
const T mbeta2 = (1 - beta2);
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() &&
1 == stateV.ews() && 1 == initStateV.ews() && 1 == stateH.ews() && 1 == initStateH.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() &&
update.ordering() == stateV.ordering() &&
stateV.ordering() == initStateV.ordering() &&
stateV.ordering() == initStateM.ordering() &&
stateM.ordering() == initStateM.ordering() &&
stateM.ordering() == initStateH.ordering() && stateH.ordering() == initStateH.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
stM[i] = beta1 * initM[i] + grad[i] * mbeta1;
stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2;
stH[i] = sd::math::nd4j_max(initH[i], stV[i]);
up[i] = epsilonT * stM[i] / (sd::math::nd4j_sqrt<T, T>(stH[i]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo());
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo());
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
bool bXInHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateH.getShapeInfo());
bool bXStHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords);
const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords);
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
const auto initHOffset = bXInHSame ? xOffset : shape::getOffset(initStateH.getShapeInfo(), coords);
const auto stHOffset = bXStHSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords);
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1;
stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]);
up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt<T, T>(stH[stHOffset]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH,
NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdater_, (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
}
}
}
}
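The AMSGrad rule implemented above, sketched in LaTeX; it matches Adam except that the second moment is replaced by its running maximum \hat{v} before the division:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{v}_t = \max(\hat{v}_{t-1}, v_t)
\Delta_t = \alpha_t \, m_t / (\sqrt{\hat{v}_t} + \epsilon), \quad \alpha_t = \mathrm{lr} \cdot \sqrt{1 - \beta_2^{t+1}} / (1 - \beta_1^{t+1})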


@@ -0,0 +1,116 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM,
NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1,
const double dBeta2, const double dEpsilon, const int nIteration) {
const T* grad = gradient.bufferAsT<T>();
const T* initV = initStateV.bufferAsT<T>();
const T* initM = initStateM.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* stV = stateV.bufferAsT<T>();
T* stM = stateM.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
const T mbeta1T = 1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
const T mbeta1 = (1 - beta1);
const T mbeta2 = (1 - beta2);
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateV.ews() && 1 == initStateV.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() &&
update.ordering() == stateV.ordering() &&
stateV.ordering() == initStateV.ordering() &&
stateV.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
auto oneMinusBeta1Grad = grad[i] * mbeta1;
stM[i] = beta1 * initM[i] + oneMinusBeta1Grad;
stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2;
up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt<T, T>(stV[i]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo());
bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo());
bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo());
bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords);
const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords);
const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords);
const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords);
auto oneMinusBeta1Grad = grad[xOffset] * mbeta1;
stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad;
stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt<T, T>(stV[stVOffset]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM,
NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, (gradient, initStateV, initStateM, update, stateV, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
}
}
}
}
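The Nadam rule implemented above, sketched in LaTeX; the Nesterov look-ahead appears as the extra (1 - \beta_1) g_t term in the numerator:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\Delta_t = \mathrm{lr} \cdot \frac{\beta_1 m_t + (1 - \beta_1) g_t}{(1 - \beta_1^{t+1})(\sqrt{v_t} + \epsilon)}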


@@ -0,0 +1,91 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {
const T* grad = gradient.bufferAsT<T>();
const T* init = initState.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* st = stateV.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T momentum = static_cast<T>(dMomentum);
const T momentumT = (-momentum - 1);
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && 1 == initState.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateV.ordering() && stateV.ordering() == initState.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
T prevState = momentum * init[i];
st[i] = prevState - lr * grad[i];
up[i] = prevState + momentumT * st[i];
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords);
T prevState = momentum * init[initOffset];
st[stOffset] = prevState - lr * grad[xOffset];
up[zOffset] = prevState + momentumT * st[stOffset];
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, (gradient, initState, update, stateV, dLr, dMomentum), FLOAT_TYPES);
}
}
}
}
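The Nesterov momentum rule implemented above, sketched in LaTeX; the code computes \mu v_{t-1} once as prevState and reuses it, with momentumT = -(\mu + 1):

v_t = \mu v_{t-1} - \mathrm{lr} \cdot g_t
\Delta_t = \mu v_{t-1} - (1 + \mu) v_t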


@@ -0,0 +1,91 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
template <typename T>
static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG,
const double dLr, const double dRmsDecay, const double dEpsilon) {
const T* grad = gradient.bufferAsT<T>();
const T* init = initState.bufferAsT<T>();
T* up = update.bufferAsT<T>();
T* st = stateG.bufferAsT<T>();
const T lr = static_cast<T>(dLr);
const T rmsDecay = static_cast<T>(dRmsDecay);
const T epsilon = static_cast<T>(dEpsilon);
bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && 1 == initState.ews();
bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateG.ordering() && stateG.ordering() == initState.ordering();
if (bEws1 && bSameOrdering) {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay);
up[i] = (lr * grad[i]) / (math::nd4j_sqrt<T, T>(st[i]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateG.getShapeInfo());
auto func = PRAGMA_THREADS_FOR{
int coords[MAX_RANK];
for (auto i = start; i < stop; i++) {
shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateG.getShapeInfo(), coords);
st[stOffset] = init[initOffset] * rmsDecay + grad[xOffset] * grad[xOffset] * (1 - rmsDecay);
up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
}
};
samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
return;
}
void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG,
const double dLr, const double dRmsDecay, const double dEpsilon) {
BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdater_, (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), FLOAT_TYPES);
}
}
}
}
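The RMSProp rule implemented above, sketched in LaTeX, with \rho the rmsDecay parameter:

E[g^2]_t = \rho E[g^2]_{t-1} + (1 - \rho) g_t^2
\Delta_t = \mathrm{lr} \cdot g_t / (\sqrt{E[g^2]_t} + \epsilon)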


@@ -0,0 +1,129 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void adaDeltaUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, const Nd4jLong* inMsgShapeInfo,
const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstMsg,
const Nd4jLong* stMsgShapeInfo, void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto initMsg = reinterpret_cast<const T*>(vinMsg);
const auto initMsdx = reinterpret_cast<const T*>(vinMsdx);
auto up = reinterpret_cast<T*>(vz);
auto stMsg = reinterpret_cast<T*>(vstMsg);
auto stMsdx = reinterpret_cast<T*>(vstMsdx);
__shared__ Nd4jLong xLen;
__shared__ T rhoT;
__shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, bXInMsdxSame, bXStMsdxSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
rhoT = (1 - rho);
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stMsgShapeInfo) && 1 == shape::elementWiseStride(inMsgShapeInfo) &&
1 == shape::elementWiseStride(stMsdxShapeInfo) && 1 == shape::elementWiseStride(inMsdxShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) &&
shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) &&
shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo);
bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo);
bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo);
bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, stMsgOffset = i, stMsdxOffset = i;
if (!bEWS || !bOrdering){
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords);
stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords);
initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords);
stMsdxOffset = bXStMsdxSame ? xOffset : shape::getOffset(stMsdxShapeInfo, coords);
}
stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT;
up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt<T, T>(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt<T, T>(stMsg[stMsgOffset] + epsilon));
stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT;
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void adaDeltaUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, const Nd4jLong* stMsgShapeInfo,
void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) {
const T rho = static_cast<T>(dRho);
const T epsilon = static_cast<T>(dEpsilon);
adaDeltaUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinMsg, inMsgShapeInfo,
vinMsdx, inMsdxShapeInfo, vz, zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, epsilon);
}
///////////////////////////////////////////////////////////////////
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx,
NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) {
PointersManager manager(context, "adaDeltaUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx });
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initStateMsg.getSpecialBuffer(), initStateMsg.getSpecialShapeInfo(), initStateMsdx.getSpecialBuffer(), initStateMsdx.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateMsg.getSpecialBuffer(), stateMsg.getSpecialShapeInfo(),
stateMsdx.getSpecialBuffer(), stateMsdx.getSpecialShapeInfo(), dRho, dEpsilon), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx });
manager.synchronize();
}
}
}
}
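The AdaDelta rule implemented above, sketched in LaTeX; Msg and Msdx in the code are the running means of squared gradients and squared updates:

E[g^2]_t = \rho E[g^2]_{t-1} + (1 - \rho) g_t^2
\Delta_t = g_t \cdot \sqrt{E[\Delta^2]_{t-1} + \epsilon} / \sqrt{E[g^2]_t + \epsilon}
E[\Delta^2]_t = \rho E[\Delta^2]_{t-1} + (1 - \rho) \Delta_t^2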


@@ -0,0 +1,117 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
const T lr, const T epsilon) {
const auto x = reinterpret_cast<const T*>(vx);
const auto init = reinterpret_cast<const T*>(vin);
auto up = reinterpret_cast<T*>(vz);
auto st = reinterpret_cast<T*>(vst);
__shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;
__shared__ Nd4jLong xLen;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) &&
shape::order(xShapeInfo) == shape::order(inShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;
if (!bEWS || !bOrdering) {
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
}
st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset];
up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void adaGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
const double dLr, const double dEpsilon) {
const T lr = static_cast<T>(dLr);
const T epsilon = static_cast<T>(dEpsilon);
adaGradUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vin, inShapeInfo,
vz, zShapeInfo, vst, stShapeInfo, lr, epsilon);
}
///////////////////////////////////////////////////////////////////
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState,
NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) {
PointersManager manager(context, "adaGradUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateH }, { &gradient, &initState });
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(),
gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initState.getSpecialBuffer(), initState.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(),
stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dEpsilon), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateH }, { &gradient, &initState });
manager.synchronize();
}
}
}
}
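The AdaGrad rule implemented above, sketched in LaTeX; the state h accumulates squared gradients without decay:

h_t = h_{t-1} + g_t^2
\Delta_t = \mathrm{lr} \cdot g_t / (\sqrt{h_t} + \epsilon)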


@@ -0,0 +1,142 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void adaMaxUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto initU = reinterpret_cast<const T*>(vinv);
const auto initM = reinterpret_cast<const T*>(vinm);
auto up = reinterpret_cast<T*>(vz);
auto stU = reinterpret_cast<T*>(vstV);
auto stM = reinterpret_cast<T*>(vstM);
__shared__ Nd4jLong xLen;
__shared__ T beta1T, epsilonT;
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
epsilonT = lr / (1.0 - beta1T);
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stmShapeInfo) &&
shape::order(xShapeInfo) == shape::order(inmShapeInfo) && shape::order(xShapeInfo) == shape::order(invShapeInfo) &&
shape::order(xShapeInfo) == shape::order(stvShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
if (!bEWS || !bOrdering) {
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
}
//m = B_1 * m + (1-B_1)*grad
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
//u = max(B_2 * u, |grad|)
stU[stUOffset] = sd::math::nd4j_max(beta2 * initU[initUOffset], sd::math::nd4j_abs(grad[xOffset])) + 1e-32;
up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void adaMaxUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr,
const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
adaMaxUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz,
zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
}
///////////////////////////////////////////////////////////////////
void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1,
const double dBeta2, const double dEpsilon, const int nIteration) {
PointersManager manager(context, "adaMaxUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(),
gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), initStateU.getSpecialBuffer(),
initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(),
stateU.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(),
dLr, dBeta1, dBeta2, dEpsilon, nIteration ), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
manager.synchronize();
}
}
}
}
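The AdaMax rule implemented above, sketched in LaTeX; the code adds 1e-32 to u_t to guard against division by zero, and the step size falls back to \epsilon when degenerate:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
u_t = \max(\beta_2 u_{t-1}, |g_t|)
\Delta_t = \frac{\mathrm{lr}}{1 - \beta_1^{t+1}} \cdot \frac{m_t}{u_t}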


@@ -0,0 +1,139 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm,
const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV,
const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto initU = reinterpret_cast<const T*>(vinv);
const auto initM = reinterpret_cast<const T*>(vinm);
auto up = reinterpret_cast<T*>(vz);
auto stU = reinterpret_cast<T*>(vstV);
auto stM = reinterpret_cast<T*>(vstM);
__shared__ Nd4jLong xLen;
__shared__ T epsilonT;
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1.0 - beta2T) / (1.0 - beta1T);
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
if (!bEWS || !bOrdering){
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
}
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2);
up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
adamUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
}
///////////////////////////////////////////////////////////////////
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
const double dEpsilon, const int nIteration) {
PointersManager manager(context, "adamUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initStateU.getSpecialBuffer(), initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), stateU.getSpecialShapeInfo(),
stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM });
manager.synchronize();
}
}
}
}
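For orientation, a minimal host-side sketch of driving this helper. The shapes, values, and hyperparameters are hypothetical, and it assumes NDArrayFactory and LaunchContext::defaultContext() from elsewhere in libnd4j:

#include <array/NDArrayFactory.h>
#include <ops/declarable/helpers/updatersHelpers.h>

void adamStepSketch() {
    // hypothetical example data; state arrays start at zero on the first iteration
    auto grad   = sd::NDArrayFactory::create<float>('c', {2, 2}, {0.1f, -0.2f, 0.3f, -0.4f});
    auto initU  = sd::NDArrayFactory::create<float>('c', {2, 2});
    auto initM  = sd::NDArrayFactory::create<float>('c', {2, 2});
    auto update = sd::NDArrayFactory::create<float>('c', {2, 2});
    auto stateU = sd::NDArrayFactory::create<float>('c', {2, 2});
    auto stateM = sd::NDArrayFactory::create<float>('c', {2, 2});
    // lr = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, iteration = 0
    sd::ops::helpers::updaterAdam(sd::LaunchContext::defaultContext(), grad, initU, initM,
                                  update, stateU, stateM, 0.001, 0.9, 0.999, 1e-8, 0);
}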


@@ -0,0 +1,152 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void amsGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
const void* vinm, const Nd4jLong* inmShapeInfo, const void* vinh, const Nd4jLong* inhShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM,
const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo,
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto initV = reinterpret_cast<const T*>(vinv);
const auto initM = reinterpret_cast<const T*>(vinm);
const auto initH = reinterpret_cast<const T*>(vinh);
auto up = reinterpret_cast<T*>(vz);
auto stV = reinterpret_cast<T*>(vstV);
auto stM = reinterpret_cast<T*>(vstM);
auto stH = reinterpret_cast<T*>(vstH);
__shared__ Nd4jLong xLen;
__shared__ T mbeta1, mbeta2, epsilonT;
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame, bXInHSame, bXStHSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1.0 - sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1)));
if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
epsilonT = epsilon;
mbeta1 = (1 - beta1);
mbeta2 = (1 - beta2);
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo) &&
1 == shape::elementWiseStride(sthShapeInfo) && 1 == shape::elementWiseStride(inhShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
shape::order(stvShapeInfo) == shape::order(invShapeInfo) && shape::order(invShapeInfo) == shape::order(sthShapeInfo) &&
shape::order(sthShapeInfo) == shape::order(inhShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo);
bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i;
if (!bEWS || !bOrdering){
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
initVOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
initHOffset = bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords);
stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords);
}
stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1;
stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]);
up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt<T, T>(stH[stHOffset]) + epsilon);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void amsGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
const void* vinh, const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
void* vstH, const Nd4jLong* sthShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
amsGradUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
vinh, inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration);
}
///////////////////////////////////////////////////////////////////
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH,
NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
PointersManager manager(context, "amsGradUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateV, &stateM, &stateH }, { &gradient, &initStateV, &initStateM, &initStateH });
BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
initStateH.getSpecialBuffer(), initStateH.getSpecialShapeInfo(), update.getSpecialBuffer(), update.getSpecialShapeInfo(),
stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(),
stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateV, &stateM , &stateH }, { &gradient, &initStateV, &initStateM, &initStateH });
manager.synchronize();
}
}
}
}


@@ -0,0 +1,137 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto initV = reinterpret_cast<const T*>(vinv);
const auto initM = reinterpret_cast<const T*>(vinm);
auto up = reinterpret_cast<T*>(vz);
auto stV = reinterpret_cast<T*>(vstV);
auto stM = reinterpret_cast<T*>(vstM);
__shared__ Nd4jLong xLen;
__shared__ T mbeta1T, mbeta1, mbeta2;
__shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
mbeta1T = 1.0 - sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
mbeta1 = (1 - beta1);
mbeta2 = (1 - beta2);
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
shape::order(stvShapeInfo) == shape::order(invShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;
if (!bEWS || !bOrdering){
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
}
auto oneMinusBeta1Grad = grad[xOffset] * mbeta1;
stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad;
stV[stUOffset] = beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2;
up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt<T, T>(stV[stUOffset]) + epsilon);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void nadamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo,
const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM,
const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
const T lr = static_cast<T>(dLr);
const T beta1 = static_cast<T>(dBeta1);
const T beta2 = static_cast<T>(dBeta2);
const T epsilon = static_cast<T>(dEpsilon);
const T iteration = static_cast<T>(nIteration);
nadamUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
}
///////////////////////////////////////////////////////////////////
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM,
NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {
PointersManager manager(context, "nadamUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM });
BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(),
stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM });
manager.synchronize();
}
}
}
}


@@ -0,0 +1,117 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) {
const auto grad = reinterpret_cast<const T*>(vx);
const auto init = reinterpret_cast<const T*>(vin);
auto up = reinterpret_cast<T*>(vz);
auto st = reinterpret_cast<T*>(vst);
__shared__ Nd4jLong xLen;
__shared__ T momentumT;
__shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
momentumT = (-momentum - 1);
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) &&
shape::order(xShapeInfo) == shape::order(stShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;
if (!bEWS || !bOrdering) {
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
}
T prevState = momentum * init[initOffset];
st[stOffset] = prevState - lr * grad[xOffset];
up[zOffset] = prevState + momentumT * st[stOffset];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
const double dLr, const double dMomentum) {
const T lr = static_cast<T>(dLr);
const T momentum = static_cast<T>(dMomentum);
nesterovsUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vin, inShapeInfo,
vz, zShapeInfo, vst, stShapeInfo, lr, momentum);
}
///////////////////////////////////////////////////////////////////
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState,
NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {
PointersManager manager(context, "nesterovsUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState });
BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock,
context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initState.getSpecialBuffer(), initState.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(),
stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), dLr, dMomentum), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, &initState });
manager.synchronize();
}
}
}
}
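
A scalar reading of the kernel above: prevState folds momentum into the old velocity, and momentumT = -(momentum + 1) packs the Nesterov look-ahead into the final line. A hypothetical Java sketch, names illustrative:

static double nesterovsStep(double g, double[] v, double lr, double momentum) {
    double prev = momentum * v[0];
    v[0] = prev - lr * g;                   // new velocity, written to the state buffer
    return prev - (1.0 + momentum) * v[0];  // equals prevState + momentumT * st above
}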

@@ -0,0 +1,121 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>
namespace sd {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo,
void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
const T lr, const T rmsDecay, const T epsilon) {
const auto x = reinterpret_cast<const T*>(vx);
const auto init = reinterpret_cast<const T*>(vin);
auto up = reinterpret_cast<T*>(vz);
auto st = reinterpret_cast<T*>(vst);
__shared__ Nd4jLong xLen;
__shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) &&
shape::order(xShapeInfo) == shape::order(inShapeInfo);
bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
}
__syncthreads();
int coords[MAX_RANK];
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {
auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;
if (!bEWS || !bOrdering) {
shape::index2coords(i, xShapeInfo, coords);
xOffset = shape::getOffset(xShapeInfo, coords);
zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
}
st[stOffset] = init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay);
up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void rmsPropUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo,
void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
const double dLr, const double dRmsDecay, const double dEpsilon) {
const T lr = static_cast<T>(dLr);
const T rmsDecay = static_cast<T>(dRmsDecay);
const T epsilon = static_cast<T>(dEpsilon);
rmsPropUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vin, inShapeInfo,
vz, zShapeInfo, vst, stShapeInfo, lr, rmsDecay, epsilon);
}
///////////////////////////////////////////////////////////////////
void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG,
const double dLr, const double dRmsDecay, const double dEpsilon) {
PointersManager manager(context, "rmsPropUpdater");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({ &update, &stateG }, { &gradient, &initState });
BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock,
context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
initState.getSpecialBuffer(), initState.getSpecialShapeInfo(),
update.getSpecialBuffer(), update.getSpecialShapeInfo(),
stateG.getSpecialBuffer(), stateG.getSpecialShapeInfo(),
dLr, dRmsDecay, dEpsilon), FLOAT_TYPES);
NDArray::registerSpecialUse({ &update, &stateG }, { &gradient, &initState });
manager.synchronize();
}
}
}
}
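
The per-element body above is the textbook RMSProp rule; a hypothetical scalar sketch in Java:

static double rmsPropStep(double g, double[] s, double lr, double rmsDecay, double epsilon) {
    s[0] = rmsDecay * s[0] + (1.0 - rmsDecay) * g * g;  // running average of squared gradients
    return lr * g / (Math.sqrt(s[0]) + epsilon);        // scaled update
}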

@@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#ifndef LIBND4J_UPDATER_RMS_PROP_H
#define LIBND4J_UPDATER_RMS_PROP_H
#include <system/op_boilerplate.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
namespace helpers {
void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, const double dLr, const double dRmsDecay, const double dEpsilon);
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon);
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum);
void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon);
void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration);
}
}
}
#endif

(File diff suppressed because it is too large.)

@@ -43,6 +43,15 @@ public class ImportClassMapping {
private static final List<Class<?>> fnClasses = Arrays.<Class<?>>asList(
org.nd4j.linalg.api.ops.DynamicCustomOp.class,
org.nd4j.linalg.api.ops.NoOp.class,
org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater.class,
org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater.class,
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,

@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class AdaDeltaUpdater extends DynamicCustomOp {
public AdaDeltaUpdater() {
//
}
public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, double rho, double epsilon) {
this(gradients, stateMsg, stateMsdx, gradients, stateMsg, stateMsdx, rho, epsilon);
}
public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, @NonNull INDArray updates, @NonNull INDArray updatedStateMsg, @NonNull INDArray updatedStateMsdx, double rho, double epsilon) {
addInputArgument(gradients, stateMsg, stateMsdx);
addOutputArgument(updates, updatedStateMsg, updatedStateMsdx);
addTArgument(rho, epsilon);
}
@Override
public String opName() {
return "ada_delta_updater";
}
}
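
Note the constructor pattern shared by all of these wrappers: the short form delegates to the long form with the inputs repeated as outputs, so it updates the gradient and state arrays in place. A hedged usage sketch — shapes and the rho/epsilon values are illustrative only:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray grad = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1, 5);
INDArray msg  = Nd4j.zeros(DataType.DOUBLE, 1, 5);
INDArray msdx = Nd4j.zeros(DataType.DOUBLE, 1, 5);
Nd4j.exec(new AdaDeltaUpdater(grad, msg, msdx, 0.95, 1e-6)); // grad, msg, msdx updated in place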

@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class AdaGradUpdater extends DynamicCustomOp {
public AdaGradUpdater() {
//
}
public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double epsilon) {
this(gradients, state, gradients, state, lr, epsilon);
}
public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double epsilon) {
addInputArgument(gradients, state);
addOutputArgument(updates, updatedState);
addTArgument(lr, epsilon);
}
@Override
public String opName() {
return "ada_grad_updater";
}
}
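
For reference, the rule behind ada_grad_updater as a hypothetical scalar sketch, assuming the standard accumulate-squared-gradients formulation:

static double adaGradStep(double g, double[] s, double lr, double epsilon) {
    s[0] += g * g;                               // accumulated squared gradient (the "grad" state)
    return lr * g / (Math.sqrt(s[0]) + epsilon);
}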

@@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class AdaMaxUpdater extends DynamicCustomOp {
public AdaMaxUpdater() {
//
}
public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
}
public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
addInputArgument(gradients, stateU, stateM);
addOutputArgument(updates, updatedStateU, updatedStateM);
addTArgument(lr, beta1, beta2, epsilon);
addIArgument(iteration);
}
@Override
public String opName() {
return "ada_max_updater";
}
}
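
A hypothetical scalar sketch of AdaMax (the infinity-norm variant of Adam), assuming bias correction on the first moment only:

static double adaMaxStep(double g, double[] u, double[] m, double lr,
                         double beta1, double beta2, double epsilon, int iteration) {
    m[0] = beta1 * m[0] + (1.0 - beta1) * g;     // first moment (stateM)
    u[0] = Math.max(beta2 * u[0], Math.abs(g));  // exponentially weighted infinity norm (stateU)
    double alphaT = lr / (1.0 - Math.pow(beta1, iteration + 1));
    return alphaT * m[0] / (u[0] + epsilon);
}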

@@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class AdamUpdater extends DynamicCustomOp {
public AdamUpdater() {
//
}
public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration);
}
public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
addInputArgument(gradients, stateU, stateM);
addOutputArgument(updates, updatedStateU, updatedStateM);
addTArgument(lr, beta1, beta2, epsilon);
addIArgument(iteration);
}
@Override
public String opName() {
return "adam_updater";
}
}
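
A hypothetical scalar sketch of the Adam rule, assuming the standard bias-corrected form that the tests below validate against UpdaterJavaCode:

static double adamStep(double g, double[] u, double[] m, double lr,
                       double beta1, double beta2, double epsilon, int iteration) {
    m[0] = beta1 * m[0] + (1.0 - beta1) * g;      // first moment (stateM)
    u[0] = beta2 * u[0] + (1.0 - beta2) * g * g;  // second moment (stateU)
    double alphaT = lr * Math.sqrt(1.0 - Math.pow(beta2, iteration + 1))
            / (1.0 - Math.pow(beta1, iteration + 1));
    return alphaT * m[0] / (Math.sqrt(u[0]) + epsilon);
}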

@@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class AmsGradUpdater extends DynamicCustomOp {
public AmsGradUpdater() {
//
}
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration);
}
public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) {
addInputArgument(gradients, stateV, stateM, stateH);
addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH);
addTArgument(lr, beta1, beta2, epsilon);
addIArgument(iteration);
}
@Override
public String opName() {
return "ams_grad_updater";
}
}
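
AMSGrad is Adam plus a non-decreasing second-moment estimate (stateH here, V_HAT in the tests below); a hypothetical scalar sketch:

static double amsGradStep(double g, double[] v, double[] m, double[] vHat, double lr,
                          double beta1, double beta2, double epsilon, int iteration) {
    m[0] = beta1 * m[0] + (1.0 - beta1) * g;
    v[0] = beta2 * v[0] + (1.0 - beta2) * g * g;
    vHat[0] = Math.max(vHat[0], v[0]);            // running max keeps the denominator monotone
    double alphaT = lr * Math.sqrt(1.0 - Math.pow(beta2, iteration + 1))
            / (1.0 - Math.pow(beta1, iteration + 1));
    return alphaT * m[0] / (Math.sqrt(vHat[0]) + epsilon);
}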

@@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class NadamUpdater extends DynamicCustomOp {
public NadamUpdater() {
//
}
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration);
}
public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) {
addInputArgument(gradients, stateV, stateM);
addOutputArgument(updates, updatedStateV, updatedStateM);
addTArgument(lr, beta1, beta2, epsilon);
addIArgument(iteration);
}
@Override
public String opName() {
return "nadam_updater";
}
}

@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class NesterovsUpdater extends DynamicCustomOp {
public NesterovsUpdater() {
//
}
public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double momentum) {
this(gradients, state, gradients, state, lr, momentum);
}
public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double momentum) {
addInputArgument(gradients, state);
addOutputArgument(updates, updatedState);
addTArgument(lr, momentum);
}
@Override
public String opName() {
return "nesterovs_updater";
}
}

@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class RmsPropUpdater extends DynamicCustomOp {
public RmsPropUpdater() {
//
}
public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double decay, double epsilon) {
this(gradients, state, gradients, state, lr, decay, epsilon);
}
public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double decay, double epsilon) {
addInputArgument(gradients, state);
addOutputArgument(updates, updatedState);
addTArgument(lr, decay, epsilon);
}
@Override
public String opName() {
return "rms_prop_updater";
}
}

@@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.updaters;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
*
* @author raver119@gmail.com
*/
public class SgdUpdater extends DynamicCustomOp {
public SgdUpdater() {
//
}
public SgdUpdater(@NonNull INDArray input, double lr) {
this(input, input, lr);
}
public SgdUpdater(@NonNull INDArray input, @NonNull INDArray output, double lr) {
addInputArgument(input);
addOutputArgument(output);
addTArgument(lr);
}
@Override
public String opName() {
return "sgd_updater";
}
}
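
The simplest wrapper: with a single array, SGD scales the gradient in place by the learning rate. A hedged usage sketch mirroring the validation test below (the 0.1 learning rate is illustrative):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

INDArray g = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1, 5);
Nd4j.exec(new SgdUpdater(g, 0.1)); // g now holds 0.1 * gradient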

@@ -10686,6 +10686,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <ops/declarable/headers/util.h>
// #include <ops/declarable/headers/BarnesHutTsne.h>
// #include <ops/declarable/headers/images.h>
// #include <ops/declarable/headers/updaters.h>
// #include <system/dll.h>
// #include <helpers/shape.h>
// #include <helpers/TAD.h>

@@ -12422,6 +12422,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <ops/declarable/headers/util.h>
// #include <ops/declarable/headers/BarnesHutTsne.h>
// #include <ops/declarable/headers/images.h>
// #include <ops/declarable/headers/updaters.h>
// #include <system/dll.h>
// #include <helpers/shape.h>
// #include <helpers/TAD.h>

@@ -15,10 +15,12 @@
******************************************************************************/
package org.nd4j.linalg.learning;
import lombok.val;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.*;
@@ -58,14 +60,23 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val msgu = msg.dup();
val msdxu = msdx.dup();
UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(g3, msgu, msdxu, rho, epsilon));
assertEquals(msg, state.get("msg"));
assertEquals(msdx, state.get("msdx"));
assertEquals(g1, g2);
assertEquals(msg, msgu);
assertEquals(msdx, msdxu);
assertEquals(g1, g3);
}
}
@@ -85,13 +96,20 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val su = s.dup();
UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(g3, su, lr, epsilon));
assertEquals(s, state.get("grad"));
assertEquals(g1, g2);
assertEquals(s, su);
assertEquals(g1, g3);
}
}
@@ -118,14 +136,23 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val mu = m.dup();
val vu = v.dup();
UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i));
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
assertEquals(m, mu);
assertEquals(v, vu);
assertEquals(g1, g3);
}
}
@@ -150,14 +177,23 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val mu = m.dup();
val vu = v.dup();
UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(g3, vu, mu, lr, beta1, beta2, eps, i));
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
assertEquals(m, mu);
assertEquals(v, vu);
assertEquals(g1, g3);
}
}
@@ -185,15 +221,26 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val mu = m.dup();
val vu = v.dup();
val hu = vH.dup();
UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new AmsGradUpdater(g3, vu, mu, hu, lr, beta1, beta2, eps, i));
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(vH, state.get("V_HAT"));
assertEquals(g1, g2);
assertEquals(m, mu);
assertEquals(v, vu);
assertEquals(vH, hu);
assertEquals(g1, g3);
}
}
@@ -219,14 +266,23 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val vu = v.dup();
val mu = m.dup();
UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i));
assertEquals(m, state.get("M"));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
assertEquals(m, mu);
assertEquals(v, vu);
assertEquals(g1, g3);
}
}
@@ -247,13 +303,18 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val vu = v.dup();
UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(g3, vu, lr, momentum));
assertEquals(v, state.get("V"));
assertEquals(g1, g2);
assertEquals(v, vu);
assertEquals(g1, g3);
}
}
@@ -275,13 +336,19 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
val gu = g.dup();
UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps);
u.applyUpdater(g2, i, 0);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(g3, gu, lr, decay, eps));
assertEquals(g, state.get("G"));
assertEquals(g1, g2);
assertEquals(g, gu);
assertEquals(g1, g3);
}
}
@@ -294,11 +361,14 @@ public class UpdaterValidation extends BaseNd4jTest {
for( int i=0; i<3; i++ ) {
INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5);
INDArray g2 = g1.dup();
val g3 = g1.dup();
UpdaterJavaCode.applySgd(g1, lr);
Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(g3, lr));
u.applyUpdater(g2, i, 0);
assertEquals(g1, g2);
assertEquals(g1, g3);
}
}