From 69c92ca5ae517bcab4f0aad0eec92b128b01644d Mon Sep 17 00:00:00 2001 From: Oleh Date: Mon, 23 Mar 2020 06:28:31 +0200 Subject: [PATCH] Learning updaters for gradient (#335) * libnd4j raw implementation of sgd upader Signed-off-by: Oleg * libnd4j some corrections and simple test added Signed-off-by: Oleg * libnd4j some corrections after discussion Signed-off-by: Oleg * libnd4j integrate applyScalar Signed-off-by: Oleg * libnd4j raw implementation of rmsPropUpdater on cpu Signed-off-by: Oleg * libnd4j fix operations declaration Signed-off-by: Oleg * libnd4j rmsPropUpdater added, test cases for sgd, etc Signed-off-by: Oleg * libnd4j fixed several typos Signed-off-by: Oleg * libnd4j some fixes and improvements for rmsPropUpdater based on Java tests Signed-off-by: Oleg * libnd4j fixed cuda implementation, update tests and corrected behavior according java tests Signed-off-by: Oleg * libnd4j adaGrad updater added Signed-off-by: Oleg * libnd4j one minor fix for ada grad Signed-off-by: Oleg * libnd4j several more fixes for ada_grad Signed-off-by: Oleg * libnd4j nesterovs updater added Signed-off-by: Oleg * libnd4j fixed nesterovs updater behavior, several typos and rename file Signed-off-by: Oleg * libnd4j one minor typo Signed-off-by: Oleg * libnd4j ada max updater added Signed-off-by: Oleg * libnd4j fixed several typos in adaMax updater Signed-off-by: Oleg * libnd4j fixed several typos in adaMaxUpdater Signed-off-by: Oleg * libnd4j several fixes for adaMax, added Adam Updater Signed-off-by: Oleg * libnd4j adaDeltaUpdater added, minor fixes for adamUpdater Signed-off-by: Oleg * libnd4j several fixes for adaDeltaUpdater Signed-off-by: Oleg * libnd4j nadamUpdater added Signed-off-by: Oleg * libnd4j one more correction for nadam updater Signed-off-by: Oleg * libnd4j several fixes for nadam updater and added amsGradUpdater Signed-off-by: Oleg * libnd4j several typos fixed in amsGradUpdater Signed-off-by: Oleg * libnd4j some corrections and added f order support rmsProp updater Signed-off-by: Oleg * libnd4j added support of f order for all updaters and modify tests for testing in place Signed-off-by: Oleg * libnd4j fixed issues for updates when not in place mode used, added tests for f order Signed-off-by: Oleg * libnd4j added input shape checks Signed-off-by: Oleg * libnd4j some corrections for different cases handling Signed-off-by: Oleg * libnd4j some code clean up and optimize per request Signed-off-by: Oleg * libnd4j updaters refactoring after review Signed-off-by: Oleg * SgdUpdater wrapper Signed-off-by: raver119 * first test Signed-off-by: raver119 * RmsPropUpdater added Signed-off-by: raver119 * NadamUpdater + NesterovsUpdater Signed-off-by: raver119 * AmsGradUpdater Signed-off-by: raver119 * AdamUpdater added Signed-off-by: raver119 * AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater Signed-off-by: raver119 * AdaGradUpdater test added Signed-off-by: raver119 * libnd4j remove input parameters parsing through NDArray, split implementation of helpers to separate files, added some rename, etc Signed-off-by: Oleg * libnd4j next step to split operations implementation into separate files Signed-off-by: Oleg * libnd4j merge master and minor corrections Signed-off-by: Oleg * libnd4j revert some changes of split implementation Signed-off-by: Oleg * libnd4j forgot to add header file Signed-off-by: Oleg * public default constructors Signed-off-by: raver119 * ImportClassMapping updated Signed-off-by: raver119 Co-authored-by: raver119 --- .../include/ops/declarable/CustomOperations.h | 1 + 
.../generic/updaters/adaDeltaUpdater.cpp | 81 ++ .../generic/updaters/adaGradUpdater.cpp | 77 ++ .../generic/updaters/adaMaxUpdater.cpp | 93 ++ .../generic/updaters/adamUpdater.cpp | 92 ++ .../generic/updaters/amsGradUpdater.cpp | 98 ++ .../generic/updaters/nadamUpdater.cpp | 92 ++ .../generic/updaters/nesterovsUpdater.cpp | 75 + .../generic/updaters/rmsPropUpdater.cpp | 80 ++ .../generic/updaters/sgdUpdater.cpp | 61 + .../include/ops/declarable/headers/updaters.h | 210 +++ .../helpers/cpu/updaterAdaDelta.cpp | 108 ++ .../declarable/helpers/cpu/updaterAdaGrad.cpp | 91 ++ .../declarable/helpers/cpu/updaterAdaMax.cpp | 113 ++ .../declarable/helpers/cpu/updaterAdam.cpp | 113 ++ .../declarable/helpers/cpu/updaterAmsGrad.cpp | 126 ++ .../declarable/helpers/cpu/updaterNadam.cpp | 116 ++ .../helpers/cpu/updaterNesterovs.cpp | 91 ++ .../declarable/helpers/cpu/updaterRmsProp.cpp | 91 ++ .../helpers/cuda/updaterAdaDelta.cu | 129 ++ .../declarable/helpers/cuda/updaterAdaGrad.cu | 117 ++ .../declarable/helpers/cuda/updaterAdaMax.cu | 142 ++ .../declarable/helpers/cuda/updaterAdam.cu | 139 ++ .../declarable/helpers/cuda/updaterAmsGrad.cu | 152 ++ .../declarable/helpers/cuda/updaterNadam.cu | 137 ++ .../helpers/cuda/updaterNesterovs.cu | 117 ++ .../declarable/helpers/cuda/updaterRmsProp.cu | 121 ++ .../ops/declarable/helpers/updatersHelpers.h | 44 + .../layers_tests/DeclarableOpsTests18.cpp | 1229 +++++++++++++++++ .../converters/ImportClassMapping.java | 9 + .../ops/impl/updaters/AdaDeltaUpdater.java | 47 + .../api/ops/impl/updaters/AdaGradUpdater.java | 47 + .../api/ops/impl/updaters/AdaMaxUpdater.java | 48 + .../api/ops/impl/updaters/AdamUpdater.java | 48 + .../api/ops/impl/updaters/AmsGradUpdater.java | 48 + .../api/ops/impl/updaters/NadamUpdater.java | 48 + .../ops/impl/updaters/NesterovsUpdater.java | 47 + .../api/ops/impl/updaters/RmsPropUpdater.java | 47 + .../api/ops/impl/updaters/SgdUpdater.java | 47 + .../java/org/nd4j/nativeblas/Nd4jCuda.java | 1 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 1 + .../linalg/learning/UpdaterValidation.java | 74 +- 42 files changed, 4646 insertions(+), 2 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp create mode 100644 libnd4j/include/ops/declarable/headers/updaters.h create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp create mode 100644 
libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu create mode 100644 libnd4j/include/ops/declarable/helpers/updatersHelpers.h create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 1a1624c08..f98deb784 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp new file mode 100644 index 000000000..bab205543 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adaDeltaUpdater.cpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ada_delta_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateMsg = INPUT_VARIABLE(1); + const auto initStateMsdx = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateMsg = OUTPUT_VARIABLE(1); + auto stateMsdx = OUTPUT_VARIABLE(2); + + if (gradient->isEmpty() || initStateMsg->isEmpty() || initStateMsdx->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateMsg), 0, "ADA_DELTA UPDATER OP: input state Msg must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsg->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateMsdx), 0, "ADA_DELTA UPDATER OP: input state Msdx must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateMsdx->getShapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 2 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "ADA_DELTA UPDATER OP: Rho and epsilon were not provided!"); + + double dRho, dEpsilon; + + if (block.width() > 3) { + const auto rho = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(rho->isScalar(), 0, "ADA_DELTA UPDATER OP: Rho has to be a scalar, but instead got rank %i!", rho->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_DELTA UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dRho = rho->e(0); + dEpsilon = epsilon->e(0); + } + else { + dRho = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaDelta(block.launchContext(), *gradient, *initStateMsg, *initStateMsdx, *update, *stateMsg, *stateMsdx, dRho, dEpsilon); + return Status::OK(); + } + + DECLARE_TYPES(ada_delta_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp new file mode 100644 index 000000000..a7a92b410 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adaGradUpdater.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
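For orientation, a minimal sketch of driving the new ada_delta_updater op from C++ test code. It assumes the DeclarableOp::evaluate() convenience used throughout the libnd4j test suite; the shapes, values and variable names are illustrative only, not taken from this patch:

    sd::ops::ada_delta_updater op;

    // Hypothetical inputs: gradient plus the two accumulated states (Msg, Msdx).
    auto grad = NDArrayFactory::create<float>('c', {2, 3}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f});
    auto msg  = NDArrayFactory::create<float>('c', {2, 3});   // accumulated squared gradients
    auto msdx = NDArrayFactory::create<float>('c', {2, 3});   // accumulated squared updates

    // T args: rho, epsilon -- the T_ARG(0)/T_ARG(1) branch of the op above.
    auto result = op.evaluate({&grad, &msg, &msdx}, {0.95, 1e-6}, {});
    auto update = result.at(0);   // outputs 1 and 2 are the updated Msg/Msdx states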
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ada_grad_updater, 2, 2, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateH = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, "ADA_GRAD UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); + + + bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "ADA_GRAD UPDATER OP: learning rate and epsilon were not provided!"); + + double dLr, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto epsilon = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, "ADA_GRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_GRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dEpsilon = T_ARG(1); + } + + helpers::updaterAdaGrad(block.launchContext(), *gradient, *initState, *update, *stateH, dLr, dEpsilon); + return Status::OK(); + } + + DECLARE_TYPES(ada_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp new file mode 100644 index 000000000..4e34c24f6 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adaMaxUpdater.cpp @@ -0,0 +1,93 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
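In formula form, the step computed by ada_grad_updater (implemented in the CPU helper updaterAdaGrad.cpp later in this patch) is the standard AdaGrad rule, with epsilon added outside the square root exactly as in that helper:

    s_t = s_{t-1} + g_t^2, \qquad u_t = \frac{\eta\, g_t}{\sqrt{s_t} + \epsilon}

where s is the accumulated squared-gradient state and u_t is written to the update output.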
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ada_max_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADA_MAX UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADA_MAX UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + + + bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + + int iteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "ADA_MAX UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, "ADA_MAX UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "ADA_MAX UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "ADA_MAX UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADA_MAX UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdaMax(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); + } + + DECLARE_TYPES(ada_max_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp new file mode 100644 index 000000000..a696d2388 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/adamUpdater.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
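The AdaMax step behind ada_max_updater (see updaterAdaMax.cpp further below) keeps a first-moment state m and an infinity-norm state u, with the bias correction folded into the step size (t = iteration + 1 in the helper, which also adds a tiny 1e-32 to u_t to avoid division by zero):

    m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad u_t = \max(\beta_2 u_{t-1}, |g_t|), \qquad u^{out}_t = \frac{\eta}{1-\beta_1^{t}} \cdot \frac{m_t}{u_t}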
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(adam_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateU = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateU = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateU->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateU), 0, "ADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateU->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "ADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + + auto iteration = block.getIArguments()->size() > 0 ? 
INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "ADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, "ADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "ADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "ADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "ADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAdam(block.launchContext(), *gradient, *initStateU, *initStateM, *update, *stateU, *stateM, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); + } + + DECLARE_TYPES(adam_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp new file mode 100644 index 000000000..bc0f4beac --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/amsGradUpdater.cpp @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
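For clarity, the update that adam_updater delegates to (see the CPU helper updaterAdam.cpp near the end of this patch) is the usual bias-corrected Adam step, with t = iteration + 1 taken from the iteration I arg:

    m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \qquad u_t = \eta\, \frac{\sqrt{1-\beta_2^{t}}}{1-\beta_1^{t}} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}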
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(ams_grad_updater, 4, 4, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + const auto initStateH = INPUT_VARIABLE(3); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + auto stateH = OUTPUT_VARIABLE(3); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty() || initStateH->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "AMSGRAD UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "AMSGRAD UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateH), 0, "AMSGRAD UPDATER OP: input state H must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateH->getShapeInfo()).c_str()); + + bool bParamsSupply = 8 == block.width() || 4 == block.getTArguments()->size(); + + auto iteration = block.getIArguments()->size() > 0 ?
INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "AMSGRAD UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 4) { + const auto lr = INPUT_VARIABLE(4); + const auto beta1 = INPUT_VARIABLE(5); + const auto beta2 = INPUT_VARIABLE(6); + const auto epsilon = INPUT_VARIABLE(7); + + REQUIRE_TRUE(lr->isScalar(), 0, "AMSGRAD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "AMSGRAD UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "AMSGRAD UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "AMSGRAD UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterAmsGrad(block.launchContext(), *gradient, *initStateV, *initStateM, *initStateH, + *update, *stateV, *stateM, *stateH, dLr, dBeta1, dBeta2, dEpsilon, iteration); + return Status::OK(); + } + + DECLARE_TYPES(ams_grad_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp new file mode 100644 index 000000000..c6af0686b --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/nadamUpdater.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
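The V/M/H state layout above corresponds to the standard AMSGrad rule; the exact bias-correction handling follows the CPU/CUDA helpers (not shown in this excerpt) and the Java-side UpdaterValidation tests, so the sketch below is indicative only:

    m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \qquad \hat{v}_t = \max(\hat{v}_{t-1}, v_t), \qquad u_t \approx \eta\, \frac{m_t}{\sqrt{\hat{v}_t} + \epsilon}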
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(nadam_updater, 3, 3, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initStateV = INPUT_VARIABLE(1); + const auto initStateM = INPUT_VARIABLE(2); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + auto stateM = OUTPUT_VARIABLE(2); + + // todo maybe we need an error like on Java side + if (gradient->isEmpty() || initStateV->isEmpty() || initStateM->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initStateM), 0, "NADAM UPDATER OP: input state M must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateM->getShapeInfo()).c_str()); + REQUIRE_TRUE(gradient->isSameShape(initStateV), 0, "NADAM UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initStateV->getShapeInfo()).c_str()); + + bool bParamsSupply = 7 == block.width() || 4 == block.getTArguments()->size(); + + auto nIteration = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; + + REQUIRE_TRUE(bParamsSupply, 0, "NADAM UPDATER OP: learning rate, beta 1, beta 2 and epsilon were not provided!"); + + double dLr, dBeta1, dBeta2, dEpsilon; + + if (block.width() > 3) { + const auto lr = INPUT_VARIABLE(3); + const auto beta1 = INPUT_VARIABLE(4); + const auto beta2 = INPUT_VARIABLE(5); + const auto epsilon = INPUT_VARIABLE(6); + + REQUIRE_TRUE(lr->isScalar(), 0, "NADAM UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(beta1->isScalar(), 0, "NADAM UPDATER OP: beta 1 has to be a scalar, but instead got rank %i!", beta1->rankOf()); + REQUIRE_TRUE(beta2->isScalar(), 0, "NADAM UPDATER OP: beta 2 has to be a scalar, but instead got rank %i!", beta2->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "NADAM UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dBeta1 = beta1->e(0); + dBeta2 = beta2->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dBeta1 = T_ARG(1); + dBeta2 = T_ARG(2); + dEpsilon = T_ARG(3); + } + + helpers::updaterNadam(block.launchContext(), *gradient, *initStateV, *initStateM, *update, *stateV, *stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration); + return Status::OK(); + } + + DECLARE_TYPES(nadam_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp new file mode 100644 index 000000000..c77abd448 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/nesterovsUpdater.cpp @@ -0,0 +1,75 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
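As with the other updaters, the hyperparameters for nadam_updater can also be supplied as scalar NDArrays (the block.width() > 3 branch above) instead of T args. A hypothetical sketch, again assuming the test-suite evaluate() helper; all names and values are illustrative:

    sd::ops::nadam_updater op;

    auto grad = NDArrayFactory::create<float>('c', {10});
    auto v    = NDArrayFactory::create<float>('c', {10});   // state V
    auto m    = NDArrayFactory::create<float>('c', {10});   // state M

    // Scalar hyperparameters as inputs 3..6: lr, beta1, beta2, epsilon.
    auto lr    = NDArrayFactory::create<float>(0.001f);
    auto beta1 = NDArrayFactory::create<float>(0.9f);
    auto beta2 = NDArrayFactory::create<float>(0.999f);
    auto eps   = NDArrayFactory::create<float>(1e-8f);

    // I arg 0 is the iteration counter used for bias correction.
    auto result = op.evaluate({&grad, &v, &m, &lr, &beta1, &beta2, &eps}, {}, {0});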
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(nesterovs_updater, 2, 2, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateV = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, "NESTEROVS UPDATER OP: input state V must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); + + bool bParamsSupply = 4 == block.width() || 2 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "NESTEROVS UPDATER OP: learning rate and momentum were not provided!"); + + double dLr, dMomentum; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto momentum = INPUT_VARIABLE(3); + + REQUIRE_TRUE(lr->isScalar(), 0, "NESTEROVS UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(momentum->isScalar(), 0, "NESTEROVS UPDATER OP: Momentum has to be a scalar, but instead got rank %i!", momentum->rankOf()); + + dLr = lr->e(0); + dMomentum = momentum->e(0); + } + else { + dLr = T_ARG(0); + dMomentum = T_ARG(1); + } + helpers::updaterNesterovs(block.launchContext(), *gradient, *initState, *update, *stateV, dLr, dMomentum); + return Status::OK(); + } + + DECLARE_TYPES(nesterovs_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp new file mode 100644 index 000000000..1ca318e26 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/rmsPropUpdater.cpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
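The momentum rule behind nesterovs_updater is not spelled out in this excerpt; per the commit log it was aligned with the Java-side Nesterovs implementation, which in DL4J reads (stated here for orientation, not verbatim from the helper):

    v_t = \mu\, v_{t-1} - \eta\, g_t, \qquad u_t = \mu\, v_{t-1} - (1+\mu)\, v_t

where v is the momentum state and u_t is the value written to the update output.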
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(rms_prop_updater, 2, 2, true, 0, 0) { + + const auto gradient = INPUT_VARIABLE(0); + const auto initState = INPUT_VARIABLE(1); + + auto update = OUTPUT_VARIABLE(0); + auto stateG = OUTPUT_VARIABLE(1); + + if (gradient->isEmpty() || initState->isEmpty()) + return Status::OK(); + + REQUIRE_TRUE(gradient->isSameShape(initState), 0, "RMS_PROP UPDATER OP: input state must have the same shape as gradient," + " expected shape %s, but got %s!", ShapeUtils::shapeAsString(gradient->getShapeInfo()).c_str(), + ShapeUtils::shapeAsString(initState->getShapeInfo()).c_str()); + + bool bParamsSupply = 5 == block.width() || 3 == block.getTArguments()->size(); + + REQUIRE_TRUE(bParamsSupply, 0, "RMS_PROP UPDATER OP: learning rate, rms decay and epsilon were not provided!"); + + double dLr, dRmsDecay, dEpsilon; + + if (block.width() > 2) { + const auto lr = INPUT_VARIABLE(2); + const auto rmsDecay = INPUT_VARIABLE(3); + const auto epsilon = INPUT_VARIABLE(4); + + REQUIRE_TRUE(lr->isScalar(), 0, "RMS_PROP UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + REQUIRE_TRUE(rmsDecay->isScalar(), 0, "RMS_PROP UPDATER OP: Rms decay has to be a scalar, but instead got rank %i!", rmsDecay->rankOf()); + REQUIRE_TRUE(epsilon->isScalar(), 0, "RMS_PROP UPDATER OP: Epsilon has to be a scalar, but instead got rank %i!", epsilon->rankOf()); + + dLr = lr->e(0); + dRmsDecay = rmsDecay->e(0); + dEpsilon = epsilon->e(0); + } + else { + dLr = T_ARG(0); + dRmsDecay = T_ARG(1); + dEpsilon = T_ARG(2); + } + + helpers::updaterRmsProp(block.launchContext(), *gradient, *initState, *update, *stateG, dLr, dRmsDecay, dEpsilon); + return Status::OK(); + } + + DECLARE_TYPES(rms_prop_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp new file mode 100644 index 000000000..491d7b53e --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/updaters/sgdUpdater.cpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
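For reference, rms_prop_updater implements the usual RMSProp step; assuming it mirrors DL4J's RmsProp (which these kernels are validated against in UpdaterValidation), epsilon is added before taking the square root:

    s_t = \rho\, s_{t-1} + (1-\rho)\, g_t^2, \qquad u_t = \frac{\eta\, g_t}{\sqrt{s_t + \epsilon}}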
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + +#include +#include +#include +#include +#include + +namespace sd { + namespace ops { + + CONFIGURABLE_OP_IMPL(sgd_updater, 1, 1, true, 0, 0) { + + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + bool bLearningRate = 2 == block.width() || 1 == block.getTArguments()->size(); + + REQUIRE_TRUE(bLearningRate, 0, "SGD UPDATER OP: Learning rate was not provided!"); + + if (block.width() > 1) { + const auto lr = INPUT_VARIABLE(1); + REQUIRE_TRUE(lr->isScalar(), 0, "SGD UPDATER OP: Learning rate has to be a scalar, but instead got rank %i!", lr->rankOf()); + + input->applyScalarArr(scalar::Multiply, *lr, *output); + } + else { + input->applyScalar(scalar::Multiply, T_ARG(0), *output); + } + + return Status::OK(); + } + + DECLARE_TYPES(sgd_updater) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + } +} diff --git a/libnd4j/include/ops/declarable/headers/updaters.h b/libnd4j/include/ops/declarable/headers/updaters.h new file mode 100644 index 000000000..dc08ff1f2 --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/updaters.h @@ -0,0 +1,210 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Oleh Semeniv (oleg.semeniv@gmail.com) + // + + +#ifndef LIBND4J_HEADERS_UPDATERS_H +#define LIBND4J_HEADERS_UPDATERS_H + +#include +#include +#include +#include +#include + + +namespace sd { + namespace ops { + + + /** + * SGD updater + * Input arrays: + * 0 - input array with gradients. + * Optional: + * 1 - scalar learning rate value + * Optional: + * T args + * 0 - scalar learning rate value + */ +#if NOT_EXCLUDED(OP_sgd_updater) + DECLARE_CONFIGURABLE_OP(sgd_updater, 1, 1, true, 0, 0); +#endif + + /** + * RmsPropUpdater updater + * Input arrays: + * 0 - input array with gradients. + * 1 - Initial state + * Optional: + * 2 - scalar learning rate value + * 3 - scalar rms decay + * 4 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - scalar rms decay + * 2 - epsilon + */ +#if NOT_EXCLUDED(OP_rms_prop_updater) + DECLARE_CONFIGURABLE_OP(rms_prop_updater, 2, 2, true, 0, 0); +#endif + // AdaGrad + /* Input arrays : + * 0 - input array with gradients. + * 1 - historical grad state + * Optional : + * 2 - scalar learning rate value + * 3 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - epsilon + */ +#if NOT_EXCLUDED(OP_ada_grad_updater) + DECLARE_CONFIGURABLE_OP(ada_grad_updater, 2, 2, true, 0, 0); +#endif + // AdaMax + /* Input arrays : + * 0 - input array with gradients. 
+ * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_ada_max_updater) + DECLARE_CONFIGURABLE_OP(ada_max_updater, 3, 3, true, 0, 0); +#endif + // Nesterov's momentum + /* Input arrays : + * 0 - input array with gradients. + * 1 - V grad state + * Optional : + * 2 - scalar learning rate value + * 3 - scalar momentum value + * Optional: + * T args + * 0 - learning rate value + * 1 - momentum value + */ +#if NOT_EXCLUDED(OP_nesterovs_updater) + DECLARE_CONFIGURABLE_OP(nesterovs_updater, 2, 2, true, 0, 0); +#endif + // Adam + /* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_adam_updater) + DECLARE_CONFIGURABLE_OP(adam_updater, 3, 3, true, 0, 0); +#endif + // AdaDelta + /* Input arrays : + * 0 - input array with gradients. + * 1 - state Msg - accumulated squared gradients + * 2 - state Msdx - accumulated squared updates + * Optional : + * 3 - rho value + * 4 - epsilon + * Optional: + * T args + * 0 - rho + * 1 - epsilon + */ +#if NOT_EXCLUDED(OP_ada_delta_updater) + DECLARE_CONFIGURABLE_OP(ada_delta_updater, 3, 3, true, 0, 0); +#endif + // Nadam + /* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V + * 2 - gradient state M + * Optional : + * 3 - scalar learning rate value + * 4 - beta 1 value + * 5 - beta 2 value + * 6 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_nadam_updater) + DECLARE_CONFIGURABLE_OP(nadam_updater, 3, 3, true, 0, 0); +#endif + // AmsGrad + /* Input arrays : + * 0 - input array with gradients. + * 1 - gradient state V - sqrd gradients + * 2 - gradient state M - moving avg + * 3 - gradient state H - max + * Optional : + * 4 - scalar learning rate value + * 5 - beta 1 value + * 6 - beta 2 value + * 7 - epsilon + * Optional: + * T args + * 0 - scalar learning rate value + * 1 - beta 1 value + * 2 - beta 2 value + * 3 - epsilon + * Optional: + * I args + * 0 - iteration + */ +#if NOT_EXCLUDED(OP_ams_grad_updater) + DECLARE_CONFIGURABLE_OP(ams_grad_updater, 4, 4, true, 0, 0); +#endif +} +} + +#endif diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp new file mode 100644 index 000000000..e80018348 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaDelta.cpp @@ -0,0 +1,108 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adaDeltaUpdater_(const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { + + const T* grad = gradient.bufferAsT(); + const T* initMsg = initStateMsg.bufferAsT(); + const T* initMsdx = initStateMsdx.bufferAsT(); + + T* up = update.bufferAsT(); + T* stMsg = stateMsg.bufferAsT(); + T* stMsdx = stateMsdx.bufferAsT(); + + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + const T rhoT = (1 - rho); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateMsg.ews() && 1 == initStateMsg.ews() && 1 == stateMsdx.ews() && 1 == initStateMsdx.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateMsdx.ordering() && + stateMsdx.ordering() == initStateMsdx.ordering() && + stateMsdx.ordering() == initStateMsg.ordering() && stateMsg.ordering() == initStateMsg.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + stMsg[i] = rho * initMsg[i] + grad[i] * grad[i] * rhoT; + + up[i] = grad[i] * (sd::math::nd4j_sqrt(initMsdx[i] + epsilon) / sd::math::nd4j_sqrt(stMsg[i] + epsilon)); + + stMsdx[i] = rho * initMsdx[i] + up[i] * up[i] * rhoT; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsg.getShapeInfo()); + bool bXStMsgSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsg.getShapeInfo()); + bool bXInMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateMsdx.getShapeInfo()); + bool bXStMsdxSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateMsdx.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(initStateMsg.getShapeInfo(), coords); + const auto stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stateMsg.getShapeInfo(), coords); + const auto initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(initStateMsdx.getShapeInfo(), coords); + const auto stMsdxOffset = bXStMsdxSame ?
xOffset : shape::getOffset(stateMsdx.getShapeInfo(), coords); + + + stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; + + up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); + + stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdater_, (gradient, initStateMsg, initStateMsdx, update, stateMsg, stateMsdx, dRho, dEpsilon), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp new file mode 100644 index 000000000..280597d31 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaGrad.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { + + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + st[i] = init[i] + grad[i] * grad[i]; + up[i] = (lr * grad[i]) / (math::nd4j_sqrt(st[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo()); + bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords); + const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords); + + st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset]; + up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, + const double dLr, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp new file mode 100644 index 000000000..ae986f901 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdaMax.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adaMaxUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + //m = B_1 * m + (1-B_1)*grad + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + //u = max(B_2 * u, |grad|) + stU[i] = sd::math::nd4j_max((beta2 * initU[i]), sd::math::nd4j_abs(grad[i])) + 1e-32; + + up[i] = stM[i] * epsilonT / stU[i]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords); + const auto stUOffset = bXStVSame ? 
xOffset : shape::getOffset(stateU.getShapeInfo(), coords); + const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + + //m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + //u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max((beta2 * initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; + + up[zOffset] = stM[stMOffset] * epsilonT / stU[stUOffset]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp new file mode 100644 index 000000000..b8eab1e6f --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAdam.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void adamUpdater_(const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, + NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initU = initStateU.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stU = stateU.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + const T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + T epsilonT = lr * sd::math::nd4j_sqrt(1. 
- beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateU.ews() && 1 == initStateU.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateU.ordering() && + stateU.ordering() == initStateU.ordering() && + stateU.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * (1 - beta1); + stU[i] = beta2 * initU[i] + grad[i] * grad[i] * (1 - beta2); + + up[i] = (stM[i] * epsilonT) / (sd::math::nd4j_sqrt(stU[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateU.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateU.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initUOffset = bXInVSame ? xOffset : shape::getOffset(initStateU.getShapeInfo(), coords); + const auto stUOffset = bXStVSame ? xOffset : shape::getOffset(stateU.getShapeInfo(), coords); + const auto initMOffset = bXInVSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); + + up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdater_, (gradient, initStateU, initStateM, update, stateU, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp new file mode 100644 index 000000000..686c22cbe --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterAmsGrad.cpp @@ -0,0 +1,126 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. 
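The adamUpdater_ helper above folds both the learning rate and the usual Adam bias correction into a single scalar epsilonT = lr * sqrt(1 - beta2^(t+1)) / (1 - beta1^(t+1)), falling back to epsilon when that expression is zero, NaN or infinite. A stripped-down sketch of the same per-element math, assuming contiguous buffers and in-place moment state; the function and parameter names are illustrative only:

#include <cmath>
#include <cstddef>
#include <vector>

// One Adam step in the helper's formulation:
//   m <- beta1*m + (1-beta1)*g
//   v <- beta2*v + (1-beta2)*g^2
//   update <- epsilonT * m / (sqrt(v) + epsilon)
void adamStep(const std::vector<float>& grad, std::vector<float>& m, std::vector<float>& v,
              std::vector<float>& update, float lr, float beta1, float beta2,
              float epsilon, int iteration) {
    float beta1T = std::pow(beta1, iteration + 1);
    float beta2T = std::pow(beta2, iteration + 1);
    float epsilonT = lr * std::sqrt(1.0f - beta2T) / (1.0f - beta1T);
    if (!std::isfinite(epsilonT) || epsilonT == 0.0f)
        epsilonT = epsilon;   // same guard as the helper
    for (std::size_t i = 0; i < grad.size(); ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
        update[i] = (m[i] * epsilonT) / (std::sqrt(v[i]) + epsilon);
    }
}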
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void amsGradUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, + NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + const T* initH = initStateH.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + T* stH = stateH.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + T epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && + 1 == stateV.ews() && 1 == initStateV.ews() && 1 == stateH.ews() && 1 == initStateH.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && + stateM.ordering() == initStateM.ordering() && + stateM.ordering() == initStateH.ordering() && stateH.ordering() == initStateH.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + stM[i] = beta1 * initM[i] + grad[i] * mbeta1; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + stH[i] = sd::math::nd4j_max(initH[i], stV[i]); + + up[i] = epsilonT * stM[i] / (sd::math::nd4j_sqrt(stH[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), 
stateM.getShapeInfo()); + bool bXInHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateH.getShapeInfo()); + bool bXStHSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords); + const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords); + const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + const auto initHOffset = bXInHSame ? xOffset : shape::getOffset(initStateH.getShapeInfo(), coords); + const auto stHOffset = bXStHSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords); + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + + up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, + NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdater_, (gradient, initStateV, initStateM, initStateH, update, stateV, stateM, stateH, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp new file mode 100644 index 000000000..82ade0f16 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNadam.cpp @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
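AMSGrad, implemented by amsGradUpdater_ above, keeps a third state array: the running maximum of the second moment, so the effective per-element step size can never grow back once it has shrunk. A minimal sketch under the same assumptions as before (contiguous buffers, in-place state, illustrative names):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One AMSGrad step:
//   m <- beta1*m + (1-beta1)*g
//   v <- beta2*v + (1-beta2)*g^2
//   h <- max(h, v)                     // monotone second moment
//   update <- epsilonT * m / (sqrt(h) + epsilon)
void amsGradStep(const std::vector<float>& grad, std::vector<float>& m, std::vector<float>& v,
                 std::vector<float>& h, std::vector<float>& update, float lr,
                 float beta1, float beta2, float epsilon, int iteration) {
    float epsilonT = lr * std::sqrt(1.0f - std::pow(beta2, iteration + 1))
                        / (1.0f - std::pow(beta1, iteration + 1));
    if (!std::isfinite(epsilonT) || epsilonT == 0.0f) epsilonT = epsilon;
    for (std::size_t i = 0; i < grad.size(); ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
        h[i] = std::max(h[i], v[i]);
        update[i] = epsilonT * m[i] / (std::sqrt(h[i]) + epsilon);
    }
}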
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void nadamUpdater_(const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + + const T* grad = gradient.bufferAsT(); + const T* initV = initStateV.bufferAsT(); + const T* initM = initStateM.bufferAsT(); + + T* up = update.bufferAsT(); + T* stV = stateV.bufferAsT(); + T* stM = stateM.bufferAsT(); + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + const T mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + const T mbeta1 = (1 - beta1); + const T mbeta2 = (1 - beta2); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateM.ews() && 1 == initStateM.ews() && 1 == stateV.ews() && 1 == initStateV.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && + update.ordering() == stateV.ordering() && + stateV.ordering() == initStateV.ordering() && + stateV.ordering() == initStateM.ordering() && stateM.ordering() == initStateM.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + auto oneMinusBeta1Grad = grad[i] * mbeta1; + + stM[i] = beta1 * initM[i] + oneMinusBeta1Grad; + stV[i] = beta2 * initV[i] + grad[i] * grad[i] * mbeta2; + + up[i] = (lr * ((stM[i] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateV.getShapeInfo()); + bool bXStVSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo()); + bool bXInMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initStateM.getShapeInfo()); + bool bXStMSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateM.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initVOffset = bXInVSame ? xOffset : shape::getOffset(initStateV.getShapeInfo(), coords); + const auto stVOffset = bXStVSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords); + const auto initMOffset = bXInMSame ? xOffset : shape::getOffset(initStateM.getShapeInfo(), coords); + const auto stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stateM.getShapeInfo(), coords); + + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + + up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stVOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdater_, (gradient, initStateV, initStateM, update, stateV, stateM, dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp new file mode 100644 index 000000000..82e21ace7 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterNesterovs.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
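Nadam, implemented by nadamUpdater_ above, applies a Nesterov-style look-ahead to the first moment: the numerator mixes the fresh gradient back in (beta1*m + (1-beta1)*g) before dividing by the bias-correction term 1 - beta1^(t+1). A minimal sketch, assuming contiguous buffers and in-place state; names are illustrative:

#include <cmath>
#include <cstddef>
#include <vector>

// One Nadam step in the helper's formulation.
void nadamStep(const std::vector<float>& grad, std::vector<float>& m, std::vector<float>& v,
               std::vector<float>& update, float lr, float beta1, float beta2,
               float epsilon, int iteration) {
    float mbeta1T = 1.0f - std::pow(beta1, iteration + 1);  // bias-correction denominator
    for (std::size_t i = 0; i < grad.size(); ++i) {
        float gPart = (1.0f - beta1) * grad[i];
        m[i] = beta1 * m[i] + gPart;
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
        // look-ahead numerator: beta1*m_t + (1-beta1)*g_t
        update[i] = (lr * ((m[i] * beta1 + gPart) / mbeta1T))
                  / (std::sqrt(v[i]) + epsilon);
    }
}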
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +static void nesterovsUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { + + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateV.bufferAsT(); + + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + const T momentumT = (-momentum - 1); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateV.ews() && 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateV.ordering() && stateV.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + T prevState = momentum * init[i]; + st[i] = prevState - lr * grad[i]; + up[i] = prevState + momentumT * st[i]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo()); + bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateV.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords); + const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateV.getShapeInfo(), coords); + + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdater_, (gradient, initState, update, stateV, dLr, dMomentum), FLOAT_TYPES); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp new file mode 100644 index 000000000..a0b9f731e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/updaterRmsProp.cpp @@ -0,0 +1,91 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
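The Nesterov-momentum helper above writes the new velocity v' = momentum*v - lr*g and emits update = momentum*v + (-momentum-1)*v', which expands algebraically to -momentum^2*v + (1+momentum)*lr*g. A minimal sketch of one step under the usual assumptions (contiguous buffers, in-place velocity, illustrative names):

#include <cstddef>
#include <vector>

// One Nesterov-momentum step as formulated in the helper:
//   v'     = momentum*v - lr*g
//   update = momentum*v - (momentum + 1)*v'
void nesterovsStep(const std::vector<float>& grad, std::vector<float>& v,
                   std::vector<float>& update, float lr, float momentum) {
    for (std::size_t i = 0; i < grad.size(); ++i) {
        float prev = momentum * v[i];
        v[i] = prev - lr * grad[i];
        update[i] = prev + (-momentum - 1.0f) * v[i];
    }
}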
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +template +static void rmsPropUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, const double dEpsilon) { + + const T* grad = gradient.bufferAsT(); + const T* init = initState.bufferAsT(); + + T* up = update.bufferAsT(); + T* st = stateG.bufferAsT(); + + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); + + bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateG.ews() && 1 == initState.ews(); + bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateG.ordering() && stateG.ordering() == initState.ordering(); + + if (bEws1 && bSameOrdering) { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + st[i] = init[i] * rmsDecay + grad[i] * grad[i] * (1 - rmsDecay) ; + up[i] = (lr * grad[i]) / ( math::nd4j_sqrt(st[i]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; + } + + bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo()); + bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo()); + bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateG.getShapeInfo()); + + auto func = PRAGMA_THREADS_FOR{ + + int coords[MAX_RANK]; + for (auto i = start; i < stop; i++) { + shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords); + const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords); + const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords); + const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords); + const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateG.getShapeInfo(), coords); + + st[stOffset] = init[initOffset] * rmsDecay + grad[xOffset] * grad[xOffset] * (1 - rmsDecay) ; + up[zOffset] = (lr * grad[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); + } + }; + + samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1); + return; +} + +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, const double dEpsilon) { + BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdater_, (gradient, initState, update, stateG, dLr, dRmsDecay, dEpsilon), FLOAT_TYPES); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu new file mode 100644 index 000000000..33272ff57 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaDelta.cu @@ -0,0 +1,129 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. 
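RMSProp, implemented by rmsPropUpdater_ above, keeps a decayed average of squared gradients and divides the step by its root. A minimal sketch, assuming contiguous buffers and in-place state; names are illustrative:

#include <cmath>
#include <cstddef>
#include <vector>

// One RMSProp step:
//   s <- rmsDecay*s + (1-rmsDecay)*g^2
//   update <- lr*g / (sqrt(s) + epsilon)
void rmsPropStep(const std::vector<float>& grad, std::vector<float>& s,
                 std::vector<float>& update, float lr, float rmsDecay, float epsilon) {
    for (std::size_t i = 0; i < grad.size(); ++i) {
        s[i] = rmsDecay * s[i] + (1.0f - rmsDecay) * grad[i] * grad[i];
        update[i] = (lr * grad[i]) / (std::sqrt(s[i]) + epsilon);
    }
}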
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adaDeltaUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinMsg, const Nd4jLong* inMsgShapeInfo, + const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, + const Nd4jLong* stMsgShapeInfo, void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const T rho, const T epsilon) { + + const auto grad = reinterpret_cast(vx); + const auto initMsg= reinterpret_cast(vinMsg); + const auto initMsdx = reinterpret_cast(vinMsdx); + + auto up = reinterpret_cast(vz); + auto stMsg = reinterpret_cast(vstMsg); + auto stMsdx = reinterpret_cast(vstMsdx); + + __shared__ Nd4jLong xLen; + __shared__ T rhoT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInMsgSame, bXStMsgSame, bXInMsdxSame, bXStMsdxSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + rhoT = (1 - rho); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stMsgShapeInfo) && 1 == shape::elementWiseStride(inMsgShapeInfo) && + 1 == shape::elementWiseStride(stMsdxShapeInfo) && 1 == shape::elementWiseStride(inMsdxShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stMsgShapeInfo) && + shape::order(stMsgShapeInfo) == shape::order(inMsgShapeInfo) && shape::order(inMsgShapeInfo) == shape::order(stMsdxShapeInfo) && + shape::order(stMsdxShapeInfo) == shape::order(inMsdxShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsgShapeInfo); + bXStMsgSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsgShapeInfo); + bXInMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, inMsdxShapeInfo); + bXStMsdxSame = shape::haveSameShapeAndStrides(xShapeInfo, stMsdxShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMsgOffset = i, initMsdxOffset = i, stMsgOffset = i, stMsdxOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMsgOffset = bXInMsgSame ? xOffset : shape::getOffset(inMsgShapeInfo, coords); + stMsgOffset = bXStMsgSame ? xOffset : shape::getOffset(stMsgShapeInfo, coords); + initMsdxOffset = bXInMsdxSame ? xOffset : shape::getOffset(inMsdxShapeInfo, coords); + stMsdxOffset = bXStMsdxSame ? 
xOffset : shape::getOffset(stMsdxShapeInfo, coords); + } + + stMsg[stMsgOffset] = rho * initMsg[initMsgOffset] + grad[xOffset] * grad[xOffset] * rhoT; + + up[zOffset] = grad[xOffset] * (sd::math::nd4j_sqrt(initMsdx[initMsdxOffset] + epsilon) / sd::math::nd4j_sqrt(stMsg[stMsgOffset] + epsilon)); + + stMsdx[stMsdxOffset] = rho * initMsdx[initMsdxOffset] + up[zOffset] * up[zOffset] * rhoT; + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adaDeltaUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinMsg, const Nd4jLong* inMsgShapeInfo, const void* vinMsdx, const Nd4jLong* inMsdxShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstMsg, const Nd4jLong* stMsgShapeInfo, + void* vstMsdx, const Nd4jLong* stMsdxShapeInfo, const double dRho, const double dEpsilon) { + + const T rho = static_cast(dRho); + const T epsilon = static_cast(dEpsilon); + + adaDeltaUpdaterCuda << > > (vx, xShapeInfo, vinMsg, inMsgShapeInfo, + vinMsdx, inMsdxShapeInfo, vz, zShapeInfo, vstMsg, stMsgShapeInfo, vstMsdx, stMsdxShapeInfo, rho, epsilon); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, + NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon) { + + PointersManager manager(context, "adaDeltaUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaDeltaUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateMsg.getSpecialBuffer(), initStateMsg.getSpecialShapeInfo(), initStateMsdx.getSpecialBuffer(), initStateMsdx.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(),stateMsg.getSpecialBuffer(), stateMsg.getSpecialShapeInfo(), + stateMsdx.getSpecialBuffer(), stateMsdx.getSpecialShapeInfo(), dRho, dEpsilon), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateMsg, &stateMsdx }, { &gradient, &initStateMsg, &initStateMsdx }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu new file mode 100644 index 000000000..f0e77826d --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaGrad.cu @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
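AdaDelta, whose CUDA kernel appears just above, keeps two accumulators, the decayed mean of squared gradients (msg) and of squared updates (msdx), and needs no learning rate: each step is scaled by the ratio of their RMS values, and msdx is refreshed only after the update has been formed. A minimal sketch, assuming contiguous buffers and in-place state; names are illustrative:

#include <cmath>
#include <cstddef>
#include <vector>

// One AdaDelta step:
//   msg  <- rho*msg + (1-rho)*g^2
//   upd   = g * sqrt(msdx + eps) / sqrt(msg + eps)   // uses the previous msdx
//   msdx <- rho*msdx + (1-rho)*upd^2
void adaDeltaStep(const std::vector<float>& grad, std::vector<float>& msg,
                  std::vector<float>& msdx, std::vector<float>& update,
                  float rho, float epsilon) {
    for (std::size_t i = 0; i < grad.size(); ++i) {
        msg[i] = rho * msg[i] + (1.0f - rho) * grad[i] * grad[i];
        update[i] = grad[i] * (std::sqrt(msdx[i] + epsilon) / std::sqrt(msg[i] + epsilon));
        msdx[i] = rho * msdx[i] + (1.0f - rho) * update[i] * update[i];
    }
}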
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adaGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const T lr, const T epsilon) { + + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + __shared__ Nd4jLong xLen; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? 
xOffset : shape::getOffset(stShapeInfo, coords); + } + + st[stOffset] = init[initOffset] + x[xOffset] * x[xOffset]; + up[zOffset] = (lr * x[xOffset]) / (math::nd4j_sqrt(st[stOffset]) + epsilon); + + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adaGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dEpsilon) { + + const T lr = static_cast(dLr); + const T epsilon = static_cast(dEpsilon); + + adaGradUpdaterCuda << > > (vx, xShapeInfo, vin, inShapeInfo, + vz, zShapeInfo, vst, stShapeInfo, lr, epsilon); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) { + + PointersManager manager(context, "adaGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateH }, { &gradient, &initState }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initState.getSpecialBuffer(), initState.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dEpsilon), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateH }, { &gradient, &initState }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu new file mode 100644 index 000000000..514440304 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdaMax.cu @@ -0,0 +1,142 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
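The CPU helpers and the CUDA kernels above all share the same layout handling: when every array has element-wise stride 1 and the same ordering, the flat loop index is used directly for all buffers; otherwise the gradient's index is converted to coordinates once, and each array either reuses the gradient's offset (when shapes and strides match) or resolves its own. A small self-contained illustration of why the stride comparison matters, using a rank-2 array with hypothetical strides rather than the library's shapeInfo structures:

#include <array>
#include <cstddef>
#include <cstdio>

// Rank-2 offset computation; strides are stand-ins for what shapeInfo encodes.
static std::size_t offsetOf(const std::array<std::size_t, 2>& strides, int row, int col) {
    return row * strides[0] + col * strides[1];
}

int main() {
    std::array<std::size_t, 2> gradStrides{3, 1};   // 2x3 array, c-order
    std::array<std::size_t, 2> stateStrides{1, 2};  // 2x3 array, f-order

    bool sameShapeAndStrides = (gradStrides == stateStrides);  // false here
    for (int r = 0; r < 2; ++r)
        for (int c = 0; c < 3; ++c) {
            std::size_t xOffset = offsetOf(gradStrides, r, c);
            // Reuse the gradient's offset only when strides match, as the helpers do.
            std::size_t stOffset = sameShapeAndStrides ? xOffset : offsetOf(stateStrides, r, c);
            std::printf("(%d,%d): grad offset %zu, state offset %zu\n", r, c, xOffset, stOffset);
        }
    return 0;
}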
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adaMaxUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T beta1T, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + beta1T = sd::math::nd4j_pow(beta1, (iteration + 1) ); + + epsilonT = lr / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(stmShapeInfo) && + shape::order(xShapeInfo) == shape::order(inmShapeInfo) && shape::order(xShapeInfo) == shape::order(invShapeInfo) && + shape::order(xShapeInfo) == shape::order(stvShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stmShapeInfo, coords); + } + + //m = B_1 * m + (1-B_1)*grad + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + //u = max(B_2 * u, |grad|) + stU[stUOffset] = sd::math::nd4j_max( (beta2* initU[initUOffset]), sd::math::nd4j_abs(grad[xOffset])) + 1e-32; + + up[zOffset] = (stM[stMOffset] * epsilonT) / stU[stUOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adaMaxUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, + const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + adaMaxUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, vz, + zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, + const double dBeta2, const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "adaMaxUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), adaMaxUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), + gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), initStateU.getSpecialBuffer(), + initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), + stateU.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), + dLr, dBeta1, dBeta2, dEpsilon, nIteration ), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu new file mode 100644 index 000000000..e23f4a5ca --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAdam.cu @@ -0,0 +1,139 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
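AdaMax, implemented by the CPU helper earlier and the CUDA kernel just above, replaces Adam's second moment with the infinity norm, u = max(beta2*u, |g|), with a tiny 1e-32 added to keep the denominator non-zero, and applies only the first-moment bias correction through epsilonT = lr / (1 - beta1^(t+1)). A minimal sketch, assuming contiguous buffers and in-place state; names are illustrative:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One AdaMax step:
//   m <- beta1*m + (1-beta1)*g
//   u <- max(beta2*u, |g|) + 1e-32
//   update <- m * epsilonT / u
void adaMaxStep(const std::vector<float>& grad, std::vector<float>& u, std::vector<float>& m,
                std::vector<float>& update, float lr, float beta1, float beta2,
                float epsilon, int iteration) {
    float epsilonT = lr / (1.0f - std::pow(beta1, iteration + 1));
    if (!std::isfinite(epsilonT) || epsilonT == 0.0f) epsilonT = epsilon;  // same guard as the helpers
    for (std::size_t i = 0; i < grad.size(); ++i) {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
        u[i] = std::max(beta2 * u[i], std::fabs(grad[i])) + 1e-32f;
        update[i] = m[i] * epsilonT / u[i];
    }
}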
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, + const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vstV, + const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initU = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stU = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + T beta1T = sd::math::nd4j_pow(beta1, (iteration + 1)); + T beta2T = sd::math::nd4j_pow(beta2, (iteration + 1)); + + epsilonT = lr * sd::math::nd4j_sqrt(1. - beta2T) / (1.0 - beta1T); + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? 
xOffset : shape::getOffset(stmShapeInfo, coords); + } + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1); + stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2); + + up[zOffset] = (stM[stMOffset] * epsilonT) / ( sd::math::nd4j_sqrt(stU[stUOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, + void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + adamUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, + vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, + NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, + const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "adamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + + BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateU.getSpecialBuffer(), initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), stateU.getSpecialShapeInfo(), + stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + + NDArray::registerSpecialUse({ &update, &stateU, &stateM }, { &gradient, &initStateU, &initStateM }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu new file mode 100644 index 000000000..d24c83f17 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterAmsGrad.cu @@ -0,0 +1,152 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
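The epsilonT guard shared by the Adam, AdaMax and AMSGrad helpers exists because the bias-correction denominator 1 - beta1^(t+1) can become zero (for instance with beta1 = 1.0), which would make the step scale infinite; the helpers then fall back to plain epsilon. A tiny worked check with assumed parameter values, purely illustrative:

#include <cmath>
#include <cstdio>

int main() {
    double lr = 1e-3, beta1 = 1.0, beta2 = 0.999, epsilon = 1e-8;
    int iteration = 0;
    // With beta1 == 1.0 the denominator is zero, so epsilonT comes out infinite.
    double epsilonT = lr * std::sqrt(1.0 - std::pow(beta2, iteration + 1))
                         / (1.0 - std::pow(beta1, iteration + 1));
    if (std::isnan(epsilonT) || epsilonT == 0.0 || std::isinf(epsilonT))
        epsilonT = epsilon;                            // the helpers' fallback
    std::printf("effective step scale: %g\n", epsilonT);  // prints 1e-08
    return 0;
}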
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void amsGradUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, const void* vinh, const Nd4jLong* inhShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, void* vstH, const Nd4jLong* sthShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + const auto initH = reinterpret_cast(vinh); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + auto stH = reinterpret_cast(vstH); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1, mbeta2, epsilonT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame, bXInHSame, bXStHSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + epsilonT = lr * sd::math::nd4j_sqrt(1.0 - sd::math::nd4j_pow(beta2, (iteration + 1))) / (1.0 - sd::math::nd4j_pow(beta1, (iteration + 1))); + + if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT)) + epsilonT = epsilon; + + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo) && + 1 == shape::elementWiseStride(sthShapeInfo) && 1 == shape::elementWiseStride(inhShapeInfo); + + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo) && shape::order(invShapeInfo) == shape::order(sthShapeInfo) && + shape::order(sthShapeInfo) == shape::order(inhShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + bXInHSame = shape::haveSameShapeAndStrides(xShapeInfo, inhShapeInfo); + bXStHSame = shape::haveSameShapeAndStrides(xShapeInfo, sthShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMOffset = i, initVOffset = i, initHOffset = i, stMOffset = i, stVOffset = i, stHOffset = i; + + if (!bEWS || !bOrdering){ + + shape::index2coords(i, xShapeInfo, coords); + xOffset = 
shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); + initVOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stVOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initHOffset = bXInHSame ? xOffset : shape::getOffset(inhShapeInfo, coords); + stHOffset = bXStHSame ? xOffset : shape::getOffset(sthShapeInfo, coords); + } + + stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * mbeta1; + stV[stVOffset] = beta2 * initV[initVOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + stH[stHOffset] = sd::math::nd4j_max(initH[initHOffset], stV[stVOffset]); + + up[zOffset] = epsilonT * stM[stMOffset] / (sd::math::nd4j_sqrt(stH[stHOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void amsGradUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + const void* vinh, const Nd4jLong* inhShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + void* vstH, const Nd4jLong* sthShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + amsGradUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, + vinh, inhShapeInfo, vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, vstH, sthShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, + NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "amsGradUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateV, &stateM, &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), amsGradUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + initStateH.getSpecialBuffer(), initStateH.getSpecialShapeInfo(), update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), + stateH.getSpecialBuffer(), stateH.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateV, &stateM , &stateH }, { &gradient, &initStateV, &initStateM, &initStateH }); + + 
manager.synchronize(); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu new file mode 100644 index 000000000..2ac1ec99b --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNadam.cu @@ -0,0 +1,137 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void nadamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo, + const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo, + const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) { + + const auto grad = reinterpret_cast(vx); + const auto initV = reinterpret_cast(vinv); + const auto initM = reinterpret_cast(vinm); + + auto up = reinterpret_cast(vz); + auto stV = reinterpret_cast(vstV); + auto stM = reinterpret_cast(vstM); + + __shared__ Nd4jLong xLen; + __shared__ T mbeta1T, mbeta1, mbeta2; + __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + + mbeta1T = 1.0 - sd::math::nd4j_pow(beta1, (iteration + 1)); + mbeta1 = (1 - beta1); + mbeta2 = (1 - beta2); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) && + 1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) && + shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) && + shape::order(stvShapeInfo) == shape::order(invShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo); + bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo); + bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo); + bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i; + + if (!bEWS || !bOrdering){ + + 
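+            // Slow path: the buffers are not all contiguous with matching ordering, so derive a
+            // per-array offset from the linear index via coordinates; arrays that share the
+            // gradient's shape and strides reuse xOffset directly.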
shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords); + stUOffset = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords); + initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords); + stMOffset = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords); + } + + auto oneMinusBeta1Grad = grad[xOffset] * mbeta1; + + stM[stMOffset] = beta1 * initM[initMOffset] + oneMinusBeta1Grad; + stV[stUOffset] = beta2 * initV[initUOffset] + grad[xOffset] * grad[xOffset] * mbeta2; + + up[zOffset] = (lr * ((stM[stMOffset] * beta1 + oneMinusBeta1Grad) / mbeta1T)) / (sd::math::nd4j_sqrt(stV[stUOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void nadamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, + const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, + const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + const T lr = static_cast(dLr); + const T beta1 = static_cast(dBeta1); + const T beta2 = static_cast(dBeta2); + const T epsilon = static_cast(dEpsilon); + const T iteration = static_cast(nIteration); + + nadamUpdaterCuda << > > (vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo, + vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration); +} + +/////////////////////////////////////////////////////////////////// +void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, + NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) { + + PointersManager manager(context, "nadamUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), nadamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initStateV.getSpecialBuffer(), initStateV.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), + stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateV, &stateM }, { &gradient, &initStateV, &initStateM }); + + manager.synchronize(); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu new file mode 100644 index 000000000..73616a5cd --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterNesterovs.cu @@ -0,0 +1,117 @@ +/******************************************************************************* + * Copyright (c) 2019 
Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + + +/////////////////////////////////////////////////////////////////// +template +__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) { + + const auto grad = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ T momentumT; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + xLen = shape::length(xShapeInfo); + momentumT = (-momentum - 1); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); + bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) && + shape::order(xShapeInfo) == shape::order(stShapeInfo); + + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? 
xOffset : shape::getOffset(stShapeInfo, coords); + } + + T prevState = momentum * init[initOffset]; + st[stOffset] = prevState - lr * grad[xOffset]; + up[zOffset] = prevState + momentumT * st[stOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, + const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dMomentum) { + + const T lr = static_cast(dLr); + const T momentum = static_cast(dMomentum); + nesterovsUpdaterCuda << > > (vx, xShapeInfo, vin, inShapeInfo, + vz, zShapeInfo, vst, stShapeInfo, lr, momentum); +} + +/////////////////////////////////////////////////////////////////// +void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, + NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) { + + PointersManager manager(context, "nesterovsUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState }); + BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, + context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initState.getSpecialBuffer(), initState.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateV.getSpecialBuffer(), stateV.getSpecialShapeInfo(), dLr, dMomentum), FLOAT_TYPES); + NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, &initState }); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu new file mode 100644 index 000000000..de0a5dba1 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/updaterRmsProp.cu @@ -0,0 +1,121 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ void rmsPropUpdaterCuda(const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const T lr, const T rmsDecay, const T epsilon) { + + const auto x = reinterpret_cast(vx); + const auto init = reinterpret_cast(vin); + + auto up = reinterpret_cast(vz); + auto st = reinterpret_cast(vst); + + __shared__ Nd4jLong xLen; + __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame; + + if (threadIdx.x == 0) { + + xLen = shape::length(xShapeInfo); + + bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) && + 1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo); + + bOrdering = shape::order(zShapeInfo) == shape::order(xShapeInfo) && shape::order(xShapeInfo) == shape::order(stShapeInfo) && + shape::order(xShapeInfo) == shape::order(inShapeInfo); + bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo); + bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo); + } + __syncthreads(); + + int coords[MAX_RANK]; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { + + auto xOffset = i, zOffset = i, initOffset = i, stOffset = i; + + if (!bEWS || !bOrdering) { + + shape::index2coords(i, xShapeInfo, coords); + xOffset = shape::getOffset(xShapeInfo, coords); + zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords); + initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords); + stOffset = bXStSame ? 
xOffset : shape::getOffset(stShapeInfo, coords); + } + + st[stOffset] = init[initOffset] * rmsDecay + x[xOffset] * x[xOffset] * (1 - rmsDecay) ; + up[zOffset] = (lr * x[xOffset]) / ( math::nd4j_sqrt(st[stOffset]) + epsilon); + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void rmsPropUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void *vx, const Nd4jLong *xShapeInfo, const void *vin, const Nd4jLong *inShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, + const double dLr, const double dRmsDecay, const double dEpsilon) { + + const T lr = static_cast(dLr); + const T rmsDecay = static_cast(dRmsDecay); + const T epsilon = static_cast(dEpsilon); + + rmsPropUpdaterCuda<<>>(vx, xShapeInfo, vin, inShapeInfo, + vz, zShapeInfo, vst, stShapeInfo, lr, rmsDecay, epsilon); +} + +/////////////////////////////////////////////////////////////////// +void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, + const double dLr, const double dRmsDecay, const double dEpsilon) { + + PointersManager manager(context, "rmsPropUpdater"); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&update, &stateG}, {&gradient, &initState }); + + BUILD_SINGLE_SELECTOR(gradient.dataType(), rmsPropUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, + context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(), + initState.getSpecialBuffer(), initState.getSpecialShapeInfo(), + update.getSpecialBuffer(), update.getSpecialShapeInfo(), + stateG.getSpecialBuffer(), stateG.getSpecialShapeInfo(), + dLr, dRmsDecay, dEpsilon ), FLOAT_TYPES); + + NDArray::registerSpecialUse({&update, &stateG}, {&gradient, &initState}); + + manager.synchronize(); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/updatersHelpers.h b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h new file mode 100644 index 000000000..5bd89b487 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/updatersHelpers.h @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#ifndef LIBND4J_UPDATER_RMS_PROM_H +#define LIBND4J_UPDATER_RMS_PROM_H + +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + + void updaterRmsProp(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateG, const double dLr, const double dRmsDecay, const double dEpsilon); + void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon); + void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateV, const double dLr, const double bMomentum); + void updaterAdaMax(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM, NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + void updaterAdaDelta(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateMsg, const NDArray& initStateMsdx, NDArray& update, NDArray& stateMsg, NDArray& stateMsdx, const double dRho, const double dEpsilon); + void updaterNadam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, NDArray& update, NDArray& stateV, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + void updaterAmsGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateV, const NDArray& initStateM, const NDArray& initStateH, NDArray& update, NDArray& stateV, NDArray& stateM, NDArray& stateH, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration); + +} +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp index f8de783c9..b1cafa073 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -187,3 +187,1232 @@ TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(output.equalsTo(exp)); } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { + + NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.001f); + + NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + sd::ops::sgd_updater op; + + Nd4jStatus status = op.execute({ &gradient, &lr }, { &gradient }, {}, { }); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); + +} +////////////////////////////////////////////////////////////////////// 
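+// The sgd_updater op computes update = lr * gradient. TestUpdaterSgd1 above passes lr as an
+// input array; the tests below pass it as a floating-point argument and repeat the check with
+// 'f'-ordered inputs. Reference sketch of the expected values (editorial, not used by the tests):
+//   for (size_t i = 0; i < n; ++i) update[i] = lr * grad[i];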
+TEST_F(DeclarableOpsTests18, TestUpdaterSgd2) { + + NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + sd::ops::sgd_updater op; + + Nd4jStatus status = op.execute({ &gradient }, { &gradient }, { 0.001f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd3) { + + NDArray gradientC('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + + NDArray updateC('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + NDArray gradient('f', { 1, 5 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + gradient.assign(gradientC); + update.assign(updateC); + + sd::ops::sgd_updater op; + + auto results = op.evaluate({ &gradient }, { 0.001f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm1) { + + NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto decay = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init, &lr, &decay, &epsilon }, { &grad0, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0('c', { 1, 5 }, { 0.4472121903197142, 0.4472095514452829, 0.4472135169488324, 0.44721352981195367, 0.44721349127249754 }, DataType::FLOAT32); + NDArray stateG0('c', { 1, 5 }, { 0.00164065126484513, 0.00055124687044416, 0.03816546608068996, 0.04711672627124962, 0.02749591463177582 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + + NDArray grad1('c', { 1, 5 }, { 0.0139725673943758, 0.19333727657794952, 0.9288347363471985, 0.9253600239753723, 0.3578299283981323 }, DataType::FLOAT32); + status = op.execute({ &grad1, &init, &lr, &decay, &epsilon }, { &grad1, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.03528177364993147, 0.3952537075263024, 0.32964378302079766, 0.31269398966616074, 0.1984174163852542 }, DataType::FLOAT32); + NDArray stateG1('c', { 1, 5 }, { 0.00156838033358239, 0.00239264965265088, 0.07939389114891399, 0.08757544865627226, 0.03252323178305766 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + NDArray grad2('c', { 1, 5 }, { 0.5442887544631958, 0.5386605262756348, 0.884294331073761, 0.15599730610847473, 0.7259345054626465 }, DataType::FLOAT32); + status = op.execute({ &grad2, &init, &lr, &decay, &epsilon }, { &grad2, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray 
updateExp2('c', { 1, 5 }, { 0.4262874753567082, 0.41582357367557454, 0.2613066321005825, 0.05369221235564697, 0.3034061716240995 }, DataType::FLOAT32); + NDArray stateG2('c', { 1, 5 }, { 0.01630247372865814, 0.01678077529839554, 0.11452301978992785, 0.0844134341991137, 0.05724611550496966 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm2) { + + NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); + NDArray stateG0('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); + NDArray stateG1('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); + NDArray stateG2('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray init('f', { 1, 5 }, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::rms_prop_updater op; + auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateG0C('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); + NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + 
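+    // Expected values follow the RMSProp recurrence implemented by the helper:
+    //   state  = decay * state + (1 - decay) * g^2
+    //   update = lr * g / (sqrt(state) + epsilon)
+    // computed here in 'c' order and copied into 'f'-ordered arrays for the comparison.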
ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); + NDArray stateG1C('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); + NDArray stateG2C('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaGrad1) { + + // need Java test + + NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::ada_grad_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init, &lr, &epsilon }, { &grad0, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs1) { + + NDArray grad0('c', { 1, 5 }, { 0.6877592206001282, 0.7830561399459839, 0.7647699117660522, 0.6183066964149475, 0.3303879499435425 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init }, { &grad0, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.13067425191402435, 0.14878066658973696, 0.14530628323554992, 0.11747827231884002, 0.06277371048927306 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { -0.06877592206001282, -0.0783056139945984, -0.07647699117660522, -0.06183066964149475, -0.03303879499435425 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + NDArray grad1('c', { 1, 5 }, { 0.3676236569881439, 0.07645636051893234, 0.45949840545654297, 0.6335387825965881, 0.2953402101993561 }, DataType::FLOAT32); + status = op.execute({ &grad1, &init }, { &grad1, &init }, { 0.1f, 0.9f }, { }); + 
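+    // Expected values follow the Nesterov momentum rule used by the helper:
+    //   v      = momentum * v_prev - lr * g
+    //   update = momentum * v_prev - (1 + momentum) * v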
ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.12555699169635773, 0.07795425583422186, 0.14925105988979342, 0.17045521110296247, 0.08287606388330458 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { -0.09866069555282593, -0.0781206886470318, -0.11477913260459902, -0.11900148093700408, -0.05926893651485443 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + NDArray grad2('c', { 1, 5 }, { 0.9874004125595093, 0.41817641258239746, 0.16838215291500092, 0.00803728867322206, 0.37015461921691895 }, DataType::FLOAT32); + status = op.execute({ &grad2, &init }, { &grad2, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.26752124178409575, 0.1427312761947513, 0.12496370646357537, 0.09791828440688549, 0.11833721622824667 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { -0.18753466725349427, -0.11212626104056837, -0.12013943463563921, -0.10790506171062587, -0.09035750478506088 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs2) { + + NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto momentum = NDArrayFactory::create(0.9f); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray init('f', { 1, 5 }, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::nesterovs_updater op; + auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.9f }, { }); + + NDArray updateC('c', { 1, 5 }, { 
0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateG0C('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); + NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); + NDArray stateG1C('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); + NDArray stateG2C('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray stateU('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 
0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax2) { + + NDArray grad0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray stateU0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.00538735911250114, 0.09700437784194944, 0.08912011384963987, 0.08891847729682921, 0.01882378011941909 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.6400517821311951, 0.3779360353946686, 0.35128724575042725, 0.6554615497589111, 0.8420050740242004 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00107575360832691, 0.00129089809294599, 0.00129546826560191, 0.00163878765669416, 0.00120120308808246 }, DataType::FLOAT32); + NDArray stateU1('c', { 1, 5 }, { 0.6400517821311951, 0.9690737346410752, 0.8903099373579025, 0.888295588195324, 0.8420050740242004 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.06885380141437052, 0.12509754359722136, 0.11533682703971859, 0.1455727845430374, 0.10114190950989721 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.5984494686126709, 0.05978915095329285, 0.5749519467353821, 0.2804091274738312, 0.0192152876406908 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00190508497658779, 0.00122473022928962, 0.00181352349370876, 0.00179237223044249, 0.00110500865710834 }, DataType::FLOAT32); + NDArray stateU2('c', { 1, 5 }, { 0.6394117303490638, 0.9681046609064341, 0.8894196274205446, 0.8874072926071286, 0.8411630689501762 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.12181336813420054, 0.11856670433282851, 0.16129833900928492, 0.15905641883611676, 0.09294924732297657 }, 
DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::ada_max_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateM.assign(stateM1C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateM.assign(stateM2C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + 
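+    // Expected values below follow the Adam rule with bias correction (t = iteration + 1):
+    //   m      = beta1 * m + (1 - beta1) * g
+    //   v      = beta2 * v + (1 - beta2) * g^2
+    //   update = lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m / (sqrt(v) + epsilon)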
NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::adam_updater op; + + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam2) { + + NDArray grad0('c', { 1, 5 }, { 0.7124611735343933, 0.7283763289451599, 0.8196553587913513, 0.9501070976257324, 0.2654055953025818 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::adam_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999955614757, 0.00099999956584582, 0.00099999961419438, 0.0009999996671663, 0.00099999880851273 }, DataType::FLOAT32); + NDArray stateU0('c', { 1, 5 }, { 0.00050760092379401, 0.00053053207656763, 0.00067183490719538, 0.00090270349695879, 
0.00007044013001792 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.07124611735343932, 0.07283763289451597, 0.08196553587913512, 0.09501070976257323, 0.02654055953025817 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.4374369978904724, 0.11488933861255646, 0.6765823364257812, 0.7659900188446045, 0.04410457238554955 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00129067017716555, 0.00104532555849556, 0.00133106720937621, 0.00132869584719374, 0.00105226561254395 }, DataType::FLOAT32); + NDArray stateU1('c', { 1, 5 }, { 0.00069844444999364, 0.00054320110461789, 0.00112892673025155, 0.00148854150243139, 0.00007231490319321 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.10786520540714262, 0.07704280346632002, 0.14142721593379973, 0.16210864067077635, 0.02829696081578731 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.496029257774353, 0.11621368676424026, 0.9112075567245483, 0.5717480182647705, 0.5975669026374817 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00150986322036664, 0.00108559662275258, 0.00156079502787382, 0.00150778241516558, 0.00130066803775601 }, DataType::FLOAT32); + NDArray stateU2('c', { 1, 5 }, { 0.00094379103011182, 0.00055616352450461, 0.00195809701495322, 0.00181394875731865, 0.00042932879141777 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.14668161064386365, 0.08095989179611204, 0.21840525001287456, 0.20307257843017573, 0.08522395499795674 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::adam_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); 
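+    // The 'c'-ordered reference values are copied into 'f'-ordered holders so that the
+    // comparisons below check outputs produced from 'f'-ordered inputs against the same expectations.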
+ stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); + NDArray stateMsg0('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); + NDArray stateMsdx0('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); + + 
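+    // Expected values follow the AdaDelta recurrence (epsilon added inside both square roots):
+    //   msg    = rho * msg + (1 - rho) * g^2
+    //   update = g * sqrt(msdx + epsilon) / sqrt(msg + epsilon)
+    //   msdx   = rho * msdx + (1 - rho) * update^2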
ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); + NDArray stateMsg1('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); + NDArray stateMsdx1('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); + NDArray stateMsg2('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); + NDArray stateMsdx2('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta2) { + + NDArray grad0('c', { 1, 5 }, { 0.22060230374336243, 0.10593396425247192, 0.9027279019355774, 0.831809401512146, 0.2733047902584076 }, DataType::FLOAT32); + NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto rho = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.0e-6); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initMsg, &initMsdx, &rho, &epsilon }, { &grad0, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.0044712172817412, 0.00446815612502933, 0.00447208107763182, 0.004472071321461, 0.00447153735969189 }, DataType::FLOAT32); + NDArray stateMsg0('c', { 1, 5 }, { 0.00243326882084394, 0.0005611002391122, 0.04074588324665051, 0.03459534402219976, 0.00373477541890961 }, DataType::FLOAT32); + NDArray stateMsdx0('c', { 1, 5 }, { 0.00000099958919903, 0.00000099822095788, 0.00000099997545825, 0.00000099997109521, 0.00000099973231796 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + NDArray grad1('c', { 1, 5 }, { 0.6351608633995056, 0.21878601610660553, 0.6470938920974731, 0.3742971122264862, 0.9453978538513184 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initMsg, &initMsdx, &rho, &epsilon }, { &grad1, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00598985959779411, 0.00571609509028959, 0.00374704195122062, 0.00265092283150538, 0.00608704322078556 }, DataType::FLOAT32); + NDArray 
stateMsg1('c', { 1, 5 }, { 0.02248307149952203, 0.00292641126934659, 0.05964511434381081, 0.03987049323214412, 0.0482368917512981 }, DataType::FLOAT32); + NDArray stateMsdx1('c', { 1, 5 }, { 0.00000274353063914, 0.00000258199706405, 0.00000165199285454, 0.00000130134213338, 0.00000280235046064 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + NDArray grad2('c', { 1, 5 }, { 0.8484492301940918, 0.9634076952934265, 0.6676893830299377, 0.4450211524963379, 0.32364124059677124 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initMsg, &initMsdx, &rho, &epsilon }, { &grad2, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00685468722145889, 0.00822128238053265, 0.00386965914609878, 0.00308849888680941, 0.00279277397245112 }, DataType::FLOAT32); + NDArray stateMsg2('c', { 1, 5 }, { 0.05735222273539331, 0.04918781007340889, 0.07895331423716523, 0.04777915987899536, 0.05106222979448406 }, DataType::FLOAT32); + NDArray stateMsdx2('c', { 1, 5 }, { 0.00000495569095238, 0.00000583237140987, 0.00000231810630717, 0.0000017132162954, 0.00000305221226067 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msg + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msdx + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initMsg('f', { 1, 5 }, DataType::FLOAT32); + NDArray initMsdx('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initMsg.assign(initVC); + initMsdx.assign(initMC); + + sd::ops::ada_delta_updater op; + auto results = op.evaluate({ &grad, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); + NDArray stateMsg('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); + NDArray stateMsdx('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateMsg.assign(stateV0C); + stateMsdx.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, results.at(1), results.at(2) }, { 0.95, 1.0e-6 }, { }); + + NDArray update1C('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); + + NDArray stateV1C('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 
0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); + + update.assign(update1C); + stateMsg.assign(stateV1C); + stateMsdx.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateMsg, &stateMsdx }, { 0.95f, 1.0e-6 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); + + update.assign(update2C); + stateMsg.assign(stateV2C); + stateMsdx.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::nadam_updater op; + + Nd4jStatus status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); + NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + 
ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam2) { + + NDArray grad0('c', { 1, 5 }, { 0.8047558665275574, 0.9653639197349548, 0.31240877509117126, 0.9530212879180908, 0.01295729912817478 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::nadam_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.06008325193356386, 0.0600832558615088, 0.06008321472550684, 0.06008325560661022, 0.0600818092240132 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.00064763200471052, 0.00093192749752604, 0.00009759924275397, 0.00090824957522506, 0.0000001678916007 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.08047558665275573, 0.09653639197349546, 0.03124087750911712, 0.09530212879180906, 0.00129572991281748 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.9839006662368774, 0.8964805603027344, 0.3631269931793213, 0.00931886397302151, 0.6320028901100159 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.06273730114378717, 0.0596708938019245, 0.06226533928512862, 0.02621380498466489, 0.06059567064824535 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00161504489372718, 0.00173467296502922, 0.00022936285668667, 0.00090742816687558, 0.0003995953768165 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.17081809461116787, 0.17653080880641933, 0.06442948907613753, 0.08670380230993031, 0.06436644593253729 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.7712154984474182, 0.1282273381948471, 0.7019220590591431, 0.8883536458015442, 0.33057701587677 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM }, { }, { }); + 
ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.06062658222261493, 0.04001212712739213, 0.06906390273197544, 0.05804376499107734, 0.05097529565845974 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.00220820319387896, 0.00174938054232472, 0.00072182807082381, 0.0016956929387176, 0.00050847694486568 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2308578349947929, 0.1717004617452621, 0.12817874607443808, 0.16686878665909166, 0.09098750292696056 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::nadam_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray 
update2C('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateH0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateH1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 
0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateH2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad2) { + + NDArray grad0('c', { 1, 5 }, { 0.5730348229408264, 0.04330538213253021, 0.249028742313385, 0.6514443755149841, 0.7017051577568054 }, DataType::FLOAT32); + NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999944815292, 0.00099999269777932, 0.00099999873015716, 0.00099999951457465, 0.00099999954934402 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.0004923901284225 }, DataType::FLOAT32); + NDArray stateH0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.00049239012842255 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.05730348229408263, 0.00433053821325302, 0.0249028742313385, 0.0651444375514984, 0.07017051577568052 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.6404328346252441, 0.9432603120803833, 0.45608729124069214, 0.9097326993942261, 0.748093843460083 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134565543815267, 0.00104022434054697, 0.00130914539820157, 0.00133725290576052, 0.0013453914974122 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); + NDArray stateH1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.11561641752719877, 0.09822351559996603, 0.06802131593227385, 0.14960326373577115, 0.13796284854412078 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + 
ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.46250319480895996, 0.09698919206857681, 0.21754667162895203, 0.46824514865875244, 0.6005083918571472 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00154098993679222, 0.00103399135000281, 0.00147364850040774, 0.00149693641196572, 0.00155078467854623 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); + NDArray stateH2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.1503050952553749, 0.09810008324682712, 0.08297385150194167, 0.1814674522280693, 0.1842174028754234 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initHC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + NDArray initH('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + initH.assign(initHC); + + sd::ops::ams_grad_updater op; + auto results = op.evaluate({ &grad, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateH0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateH('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + stateH.assign(stateH0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({ &grad, &stateV, 
&stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + NDArray stateH1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + stateH.assign(stateH1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + + NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + NDArray stateH2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + stateH.assign(stateH2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 327e3c52e..ebe27bd85 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -43,6 +43,15 @@ public class ImportClassMapping { private static final List> fnClasses = Arrays.>asList( org.nd4j.linalg.api.ops.DynamicCustomOp.class, org.nd4j.linalg.api.ops.NoOp.class, + org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater.class, + 
org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater.class, + org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater.class, org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class, org.nd4j.linalg.api.ops.custom.BarnesHutGains.class, org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java new file mode 100644 index 000000000..db87ad5e4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaDeltaUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdaDeltaUpdater extends DynamicCustomOp { + + public AdaDeltaUpdater() { + // + } + + public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, double rho, double epsilon) { + this(gradients, stateMsg, stateMsdx, gradients, stateMsg, stateMsdx, rho, epsilon); + } + + public AdaDeltaUpdater(@NonNull INDArray gradients, @NonNull INDArray stateMsg, @NonNull INDArray stateMsdx, @NonNull INDArray updates, @NonNull INDArray updatedStateMsg, @NonNull INDArray updatedStateMsdx, double rho, double epsilon) { + addInputArgument(gradients, stateMsg, stateMsdx); + addOutputArgument(updates, updatedStateMsg, updatedStateMsdx); + addTArgument(rho, epsilon); + } + + @Override + public String opName() { + return "ada_delta_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java new file mode 100644 index 000000000..e2304bdfb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaGradUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdaGradUpdater extends DynamicCustomOp { + + public AdaGradUpdater() { + // + } + + public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double epsilon) { + this(gradients, state, gradients, state, lr, epsilon); + } + + public AdaGradUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double epsilon) { + addInputArgument(gradients, state); + addOutputArgument(updates, updatedState); + addTArgument(lr, epsilon); + } + + @Override + public String opName() { + return "ada_grad_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java new file mode 100644 index 000000000..483078335 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdaMaxUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdaMaxUpdater extends DynamicCustomOp { + + public AdaMaxUpdater() { + // + } + + public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration); + } + + public AdaMaxUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateU, stateM); + addOutputArgument(updates, updatedStateU, updatedStateM); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "ada_max_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java new file mode 100644 index 000000000..1ab34ae52 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AdamUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AdamUpdater extends DynamicCustomOp { + + public AdamUpdater() { + // + } + + public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateU, stateM, gradients, stateU, stateM, lr, beta1, beta2, epsilon, iteration); + } + + public AdamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateU, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateU, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateU, stateM); + addOutputArgument(updates, updatedStateU, updatedStateM); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "adam_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java new file mode 100644 index 000000000..35af113ad --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class AmsGradUpdater extends DynamicCustomOp { + + public AmsGradUpdater() { + // + } + + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration); + } + + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateV, stateM, stateH); + addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "ams_grad_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java new file mode 100644 index 000000000..ad4f374b7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class NadamUpdater extends DynamicCustomOp { + + public NadamUpdater() { + // + } + + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration); + } + + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + addInputArgument(gradients, stateV, stateM); + addOutputArgument(updates, updatedStateV, updatedStateM); + addTArgument(lr, beta1, beta2, epsilon); + addIArgument(iteration); + } + + @Override + public String opName() { + return "nadam_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java new file mode 100644 index 000000000..a277f750f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NesterovsUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class NesterovsUpdater extends DynamicCustomOp { + + public NesterovsUpdater() { + // + } + + public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double momentum) { + this(gradients, state, gradients, state, lr, momentum); + } + + public NesterovsUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double momentum) { + addInputArgument(gradients, state); + addOutputArgument(updates, updatedState); + addTArgument(lr, momentum); + } + + @Override + public String opName() { + return "nesterovs_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java new file mode 100644 index 000000000..aaf734ea8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/RmsPropUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class RmsPropUpdater extends DynamicCustomOp { + + public RmsPropUpdater() { + // + } + + public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, double lr, double decay, double epsilon) { + this(gradients, state, gradients, state, lr, decay, epsilon); + } + + public RmsPropUpdater(@NonNull INDArray gradients, @NonNull INDArray state, @NonNull INDArray updates, @NonNull INDArray updatedState, double lr, double decay, double epsilon) { + addInputArgument(gradients, state); + addOutputArgument(updates, updatedState); + addTArgument(lr, decay, epsilon); + } + + @Override + public String opName() { + return "rms_prop_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java new file mode 100644 index 000000000..ef40735a4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/SgdUpdater.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.updaters; + +import lombok.NonNull; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * + * @author raver119@gmail.com + */ +public class SgdUpdater extends DynamicCustomOp { + + public SgdUpdater() { + // + } + + public SgdUpdater(@NonNull INDArray input, double lr) { + this(input, input, lr); + } + + public SgdUpdater(@NonNull INDArray input, @NonNull INDArray output, double lr) { + addInputArgument(input); + addOutputArgument(output); + addTArgument(lr); + } + + @Override + public String opName() { + return "sgd_updater"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 17bf95031..33260da70 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -10686,6 +10686,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 80d5904a6..47791f865 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -12422,6 +12422,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java index 660b178e4..4a4d6aab6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java @@ -15,10 +15,12 @@ ******************************************************************************/ package org.nd4j.linalg.learning; +import lombok.val; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.learning.config.*; @@ -58,14 +60,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val msgu = msg.dup(); + val msdxu = msdx.dup(); UpdaterJavaCode.applyAdaDeltaUpdater(g1, msg, msdx, rho, epsilon); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(g3, msgu, msdxu, rho, epsilon)); + assertEquals(msg, state.get("msg")); assertEquals(msdx, state.get("msdx")); assertEquals(g1, g2); + + assertEquals(msg, msgu); + assertEquals(msdx, msdxu); + assertEquals(g1, g3); } } @@ -85,13 
+96,20 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val su = s.dup(); UpdaterJavaCode.applyAdaGradUpdater(g1, s, lr, epsilon); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(g3, su, lr, epsilon)); + assertEquals(s, state.get("grad")); assertEquals(g1, g2); + + assertEquals(s, su); + assertEquals(g1, g3); } } @@ -118,14 +136,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val mu = m.dup(); + val vu = v.dup(); UpdaterJavaCode.applyAdamUpdater(g1, m, v, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -150,14 +177,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val mu = m.dup(); + val vu = v.dup(); UpdaterJavaCode.applyAdaMaxUpdater(g1, m, v, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(g3, vu, mu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -185,15 +221,26 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val mu = m.dup(); + val vu = v.dup(); + val hu = vH.dup(); UpdaterJavaCode.applyAmsGradUpdater(g1, m, v, vH, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new AmsGradUpdater(g3, vu, mu, hu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(vH, state.get("V_HAT")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(vH, hu); + assertEquals(g1, g3); } } @@ -219,14 +266,23 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val vu = v.dup(); + val mu = m.dup(); UpdaterJavaCode.applyNadamUpdater(g1, m, v, lr, beta1, beta2, eps, i); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(g3, vu, mu, lr, beta1, beta2, eps, i)); + assertEquals(m, state.get("M")); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + assertEquals(m, mu); + assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -247,13 +303,19 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val vu = v.dup(); UpdaterJavaCode.applyNesterovsUpdater(g1, v, lr, momentum); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(g3, vu, lr, momentum)); assertEquals(v, state.get("V")); assertEquals(g1, g2); + + 
assertEquals(v, vu); + assertEquals(g1, g3); } } @@ -275,13 +337,20 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); + val gu = g.dup(); UpdaterJavaCode.applyRmsProp(g1, g, lr, decay, eps); u.applyUpdater(g2, i, 0); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(g3, gu, lr, decay, eps)); assertEquals(g, state.get("G")); assertEquals(g1, g2); + + assertEquals(g, gu); + assertEquals(g1, g3); + } } @@ -294,11 +363,14 @@ public class UpdaterValidation extends BaseNd4jTest { for( int i=0; i<3; i++ ) { INDArray g1 = Nd4j.linspace(DataType.DOUBLE, 1, 5, 1).reshape(1,5); INDArray g2 = g1.dup(); + val g3 = g1.dup(); UpdaterJavaCode.applySgd(g1, lr); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(g3, lr)); u.applyUpdater(g2, i, 0); assertEquals(g1, g2); + assertEquals(g1, g3); } }
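
Reviewer note (illustrative, not part of the patch): the expected constants asserted in tests such as TestUpdaterAdaDelta1 are consistent with the standard AdaDelta recurrence (Zeiler 2012) -- msg = rho*msg + (1-rho)*g^2, update = g*sqrt(msdx+eps)/sqrt(msg+eps), msdx = rho*msdx + (1-rho)*update^2 -- which also appears to be what UpdaterJavaCode.applyAdaDeltaUpdater computes. A minimal standalone sketch of one such in-place step under that assumption follows; the class and method names below are hypothetical, not library API.

// AdaDeltaStepSketch.java -- illustrative sketch only, assuming the standard
// AdaDelta (Zeiler 2012) formulation; not part of this patch or of the nd4j API.
public class AdaDeltaStepSketch {

    // One in-place AdaDelta step over flat arrays:
    //   msg  = rho*msg  + (1-rho)*g^2
    //   upd  = g * sqrt(msdx + eps) / sqrt(msg + eps)
    //   msdx = rho*msdx + (1-rho)*upd^2
    static void adaDeltaStep(double[] grad, double[] msg, double[] msdx,
                             double rho, double eps, double[] update) {
        for (int i = 0; i < grad.length; i++) {
            msg[i] = rho * msg[i] + (1 - rho) * grad[i] * grad[i];
            update[i] = grad[i] * Math.sqrt(msdx[i] + eps) / Math.sqrt(msg[i] + eps);
            msdx[i] = rho * msdx[i] + (1 - rho) * update[i] * update[i];
        }
    }

    public static void main(String[] args) {
        double[] grad = {1, 2, 3, 4, 5};
        double[] msg = new double[5];   // running average of squared gradients
        double[] msdx = new double[5];  // running average of squared updates
        double[] update = new double[5];
        adaDeltaStep(grad, msg, msdx, 0.95, 1.0e-6, update);
        System.out.println(java.util.Arrays.toString(update));
        System.out.println(java.util.Arrays.toString(msg));
        System.out.println(java.util.Arrays.toString(msdx));
    }
}

Running this with grad = {1..5}, rho = 0.95, eps = 1e-6 reproduces the first-step constants asserted above: update is approximately 0.004472 per element, msg = {0.05, 0.2, 0.45, 0.8, 1.25}, and msdx is approximately 1e-6, matching updateExp0, stateMsg0 and stateMsdx0 in TestUpdaterAdaDelta1.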