Oleh 69c92ca5ae
Learning updaters for gradient (#335)
* libnd4j raw implementation of sgd updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections and simple test added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections after discussion

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j integrate applyScalar

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j raw implementation of rmsPropUpdater on cpu

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fix operations declaration

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j rmsPropUpdater added, test cases for sgd, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed several typos

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some fixes and improvements for rmsPropUpdater based on Java tests

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed cuda implementation, updated tests and corrected behavior according to java tests

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j adaGrad updater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one minor fix for ada grad

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several more fixes for ada_grad

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j nesterovs updater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed nesterovs updater behavior, several typos and rename file

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one minor typo

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j ada max updater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed several typos in adaMax updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed several typos in adaMaxUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several fixes for adaMax, added Adam Updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j adaDeltaUpdater added, minor fixes for adamUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several fixes for adaDeltaUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j nadamUpdater added

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j one more correction for nadam updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several fixes for nadam updater and added amsGradUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j several typos fixed in amsGradUpdater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections and added f order support for rmsProp updater

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added support of f order for all updaters and modify tests for testing in place

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed issues with updates when not-in-place mode is used, added tests for f order

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added input shape checks

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some corrections for different cases handling

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j some code clean up and optimize per request

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j updaters refactoring after review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* SgdUpdater wrapper

Signed-off-by: raver119 <raver119@gmail.com>

* first test

Signed-off-by: raver119 <raver119@gmail.com>

* RmsPropUpdater added

Signed-off-by: raver119 <raver119@gmail.com>

* NadamUpdater + NesterovsUpdater

Signed-off-by: raver119 <raver119@gmail.com>

* AmsGradUpdater

Signed-off-by: raver119 <raver119@gmail.com>

* AdamUpdater added

Signed-off-by: raver119 <raver119@gmail.com>

* AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater

Signed-off-by: raver119 <raver119@gmail.com>

* AdaGradUpdater test added

Signed-off-by: raver119 <raver119@gmail.com>

* libnd4j remove input parameters parsing through NDArray, split implementation of helpers to separate files, added some rename, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j next step to split operations implementation into separate files

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j merge master and minor corrections

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j revert some changes of split implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j forgot to add header file

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* public default constructors

Signed-off-by: raver119 <raver119@gmail.com>

* ImportClassMapping updated

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
2020-03-23 07:28:31 +03:00

/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/helpers/updatersHelpers.h>
#include <execution/Threads.h>
#include <math/platformmath.h>
#include <math/templatemath.h>

namespace sd {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// AdaGrad step, applied element-wise:
//   stateH = initState + gradient^2
//   update = lr * gradient / (sqrt(stateH) + epsilon)
template <typename T>
static void adaGradUpdater_(const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH, const double dLr, const double dEpsilon) {

    const T* grad = gradient.bufferAsT<T>();
    const T* init = initState.bufferAsT<T>();

    T* up = update.bufferAsT<T>();
    T* st = stateH.bufferAsT<T>();

    const T lr = static_cast<T>(dLr);
    const T epsilon = static_cast<T>(dEpsilon);

    // Fast path: every array has element-wise stride 1 and all share one ordering ('c' or 'f'),
    // so buffer index i refers to the same logical element in all four arrays.
    bool bEws1 = 1 == gradient.ews() && 1 == update.ews() && 1 == stateH.ews() && 1 == initState.ews();
    bool bSameOrdering = gradient.ordering() == update.ordering() && update.ordering() == stateH.ordering() && stateH.ordering() == initState.ordering();

    if (bEws1 && bSameOrdering) {
        auto func = PRAGMA_THREADS_FOR{
            for (auto i = start; i < stop; i++) {
                st[i] = init[i] + grad[i] * grad[i];
                up[i] = (lr * grad[i]) / (math::nd4j_sqrt<T, T>(st[i]) + epsilon);
            }
        };

        samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
        return;
    }

    // General path: convert each linear index to coordinates, then to a per-array offset.
    // When an array has the same shape and strides as the gradient, the gradient offset is reused.
    bool bXZsame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), update.getShapeInfo());
    bool bXInSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), initState.getShapeInfo());
    bool bXStSame = shape::haveSameShapeAndStrides(gradient.getShapeInfo(), stateH.getShapeInfo());

    auto func = PRAGMA_THREADS_FOR{
        int coords[MAX_RANK];
        for (auto i = start; i < stop; i++) {
            shape::index2coordsCPU(start, i, gradient.getShapeInfo(), coords);

            const auto xOffset = shape::getOffset(gradient.getShapeInfo(), coords);
            const auto zOffset = bXZsame ? xOffset : shape::getOffset(update.getShapeInfo(), coords);
            const auto initOffset = bXInSame ? xOffset : shape::getOffset(initState.getShapeInfo(), coords);
            const auto stOffset = bXStSame ? xOffset : shape::getOffset(stateH.getShapeInfo(), coords);

            st[stOffset] = init[initOffset] + grad[xOffset] * grad[xOffset];
            up[zOffset] = (lr * grad[xOffset]) / (math::nd4j_sqrt<T, T>(st[stOffset]) + epsilon);
        }
    };

    samediff::Threads::parallel_for(func, 0, gradient.lengthOf(), 1);
}

// Dispatches to the typed implementation for the gradient's floating-point data type.
void updaterAdaGrad(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState, NDArray& update, NDArray& stateH,
                    const double dLr, const double dEpsilon) {
    BUILD_SINGLE_SELECTOR(gradient.dataType(), adaGradUpdater_, (gradient, initState, update, stateH, dLr, dEpsilon), FLOAT_TYPES);
}

} // namespace helpers
} // namespace ops
} // namespace sd
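A minimal standalone sketch of what `adaGradUpdater_` computes, written to make the two code paths concrete. It assumes plain `double` buffers and a hypothetical 2-D `View2D` type (logical shape plus element strides), which stands in for libnd4j's `NDArray` shape info and is not a libnd4j type. The flat loop plays the role of the `bEws1 && bSameOrdering` branch; the coordinate loop plays the role of the fallback built on `shape::index2coordsCPU` and `shape::getOffset`. It is single-threaded and does not use `samediff::Threads`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical 2-D view (not a libnd4j type): logical shape plus strides,
// measured in elements, over a flat buffer.
struct View2D {
    std::size_t rows, cols;
    std::size_t rowStride, colStride;
    std::size_t offset(std::size_t r, std::size_t c) const { return r * rowStride + c * colStride; }
};

// Row-major ('c') and column-major ('f') layouts of a rows x cols array.
View2D cOrder(std::size_t rows, std::size_t cols) { return {rows, cols, cols, 1}; }
View2D fOrder(std::size_t rows, std::size_t cols) { return {rows, cols, 1, rows}; }

bool sameShape(const View2D& a, const View2D& b) { return a.rows == b.rows && a.cols == b.cols; }

bool sameLayout(const View2D& a, const View2D& b) {
    return sameShape(a, b) && a.rowStride == b.rowStride && a.colStride == b.colStride;
}

// Dense with no gaps, in either c or f order (the analogue of ews() == 1).
bool isContiguous(const View2D& v) {
    return (v.colStride == 1 && v.rowStride == v.cols) || (v.rowStride == 1 && v.colStride == v.rows);
}

// One AdaGrad step: state = init + g^2, update = lr * g / (sqrt(state) + eps).
void adaGradStep(const std::vector<double>& grad, const View2D& gv,
                 const std::vector<double>& init, const View2D& iv,
                 std::vector<double>& update, const View2D& uv,
                 std::vector<double>& state, const View2D& sv,
                 double lr, double eps) {
    // All arrays must describe the same logical shape.
    assert(sameShape(gv, iv) && sameShape(gv, uv) && sameShape(gv, sv));

    // Fast path: identical contiguous layouts, so a flat index addresses
    // the same logical element in every buffer.
    if (isContiguous(gv) && sameLayout(gv, iv) && sameLayout(gv, uv) && sameLayout(gv, sv)) {
        const std::size_t n = gv.rows * gv.cols;
        for (std::size_t i = 0; i < n; ++i) {
            state[i] = init[i] + grad[i] * grad[i];
            update[i] = lr * grad[i] / (std::sqrt(state[i]) + eps);
        }
        return;
    }

    // General path: iterate logical coordinates and compute each array's own offset.
    for (std::size_t r = 0; r < gv.rows; ++r) {
        for (std::size_t c = 0; c < gv.cols; ++c) {
            const double g = grad[gv.offset(r, c)];
            const double h = init[iv.offset(r, c)] + g * g;
            state[sv.offset(r, c)] = h;
            update[uv.offset(r, c)] = lr * g / (std::sqrt(h) + eps);
        }
    }
}
```

For example, with a C-order 2x2 gradient `{1, 2, 3, 4}`, C-order `init = {3, 12, 0, 9}`, `lr = 0.1` and `eps = 0`, the state becomes `{4, 16, 9, 25}`. If every array is C-order, the fast path stores the update as `{0.05, 0.05, 0.1, 0.08}`. If only `update` is F-order, the layouts differ, the general path runs, and the same logical values land in column-major positions: `{0.05, 0.1, 0.05, 0.08}`. This is why the fast path requires both stride 1 and a shared ordering.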