Learning updaters for gradient (#335)
* libnd4j raw implementation of sgd upader
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some corrections and simple test added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some corrections after discussion
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j integrate applyScalar
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j raw implementation of rmsPropUpdater on cpu
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fix operations declaration
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j rmsPropUpdater added, test cases for sgd, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed several typos
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some fixes and improvements for rmsPropUpdater based on Java tests
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed cuda implementation, update tests and corrected behavior according java tests
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j adaGrad updater added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j one minor fix for ada grad
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j several more fixes for ada_grad
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j nesterovs updater added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed nesterovs updater behavior, several typos and rename file
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j one minor typo
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j ada max updater added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed several typos in adaMax updater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed several typos in adaMaxUpdater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j several fixes for adaMax, added Adam Updater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j adaDeltaUpdater added, minor fixes for adamUpdater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j several fixes for adaDeltaUpdater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j nadamUpdater added
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j one more correction for nadam updater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j several fixes for nadam updater and added amsGradUpdater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j several typos fixed in amsGradUpdater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some corrections and added f order support rmsProp updater
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j added support of f order for all updaters and modify tests for testing in place
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j fixed issues for updates when not in place mode used, added tests for f order
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j added input shape checks
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some corrections for different cases handling
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j some code clean up and optimize per request
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j updaters refactoring after review
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* SgdUpdater wrapper
Signed-off-by: raver119 <raver119@gmail.com>
* first test
Signed-off-by: raver119 <raver119@gmail.com>
* RmsPropUpdater added
Signed-off-by: raver119 <raver119@gmail.com>
* NadamUpdater + NesterovsUpdater
Signed-off-by: raver119 <raver119@gmail.com>
* AmsGradUpdater
Signed-off-by: raver119 <raver119@gmail.com>
* AdamUpdater added
Signed-off-by: raver119 <raver119@gmail.com>
* AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater
Signed-off-by: raver119 <raver119@gmail.com>
* AdaGradUpdater test added
Signed-off-by: raver119 <raver119@gmail.com>
* libnd4j remove input parameters parsing through NDArray, split implementation of helpers to separate files, added some rename, etc
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j next step to split operations implementation into separate files
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j merge master and minor corrections
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j revert some changes of split implementation
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* libnd4j forgot to add header file
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* public default constructors
Signed-off-by: raver119 <raver119@gmail.com>
* ImportClassMapping updated
Signed-off-by: raver119 <raver119@gmail.com>
Co-authored-by: raver119 <raver119@gmail.com>
2020-03-23 05:28:31 +01:00
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//

#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>

namespace sd {
namespace ops {
namespace helpers {

///////////////////////////////////////////////////////////////////
// Element-wise Nesterov momentum update:
//   stateV = momentum * initState - lr * gradient
//   update = momentum * initState - (1 + momentum) * stateV
// Following the DL4J convention, the caller subtracts the update from the parameters.
template<typename T>
__global__ void nesterovsUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
                                     void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo, const T lr, const T momentum) {

    const auto grad = reinterpret_cast<const T*>(vx);   // gradient
    const auto init = reinterpret_cast<const T*>(vin);  // previous velocity (initial state)
    auto up = reinterpret_cast<T*>(vz);                 // resulting update
    auto st = reinterpret_cast<T*>(vst);                // new velocity (output state)

    __shared__ Nd4jLong xLen;
    __shared__ T momentumT;
    __shared__ bool bEWS, bOrdering, bXZsame, bXInSame, bXStSame;

    if (threadIdx.x == 0) {
        xLen = shape::length(xShapeInfo);
        momentumT = (-momentum - 1);    // -(1 + momentum)

        // the linear index can be used directly as an offset only if all arrays
        // have element-wise stride 1 and share the same ordering
        bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
               1 == shape::elementWiseStride(stShapeInfo) && 1 == shape::elementWiseStride(inShapeInfo);
        bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(xShapeInfo) == shape::order(inShapeInfo) &&
                    shape::order(xShapeInfo) == shape::order(stShapeInfo);

        // arrays with the same shape and strides as the gradient can reuse its offset
        bXZsame = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
        bXInSame = shape::haveSameShapeAndStrides(xShapeInfo, inShapeInfo);
        bXStSame = shape::haveSameShapeAndStrides(xShapeInfo, stShapeInfo);
    }
    __syncthreads();

    int coords[MAX_RANK];

    // grid-stride loop: each thread processes elements i, i + gridDim.x * blockDim.x, ...
    for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {

        auto xOffset = i, zOffset = i, initOffset = i, stOffset = i;

        if (!bEWS || !bOrdering) {
            // general case: map the linear index to coordinates, then to per-array offsets
            shape::index2coords(i, xShapeInfo, coords);
            xOffset = shape::getOffset(xShapeInfo, coords);
            zOffset = bXZsame ? xOffset : shape::getOffset(zShapeInfo, coords);
            initOffset = bXInSame ? xOffset : shape::getOffset(inShapeInfo, coords);
            stOffset = bXStSame ? xOffset : shape::getOffset(stShapeInfo, coords);
        }

        T prevState = momentum * init[initOffset];
        st[stOffset] = prevState - lr * grad[xOffset];
        up[zOffset] = prevState + momentumT * st[stOffset];
    }
}
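The per-element arithmetic of `nesterovsUpdaterCuda` can be checked on the host with a small reference. This is a minimal sketch, assuming contiguous buffers of equal length; `nesterovsReference` is a hypothetical name and not part of libnd4j.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the update computed by nesterovsUpdaterCuda.
// Assumes all buffers are contiguous and have the same length.
template <typename T>
void nesterovsReference(const std::vector<T>& grad, const std::vector<T>& initState,
                        std::vector<T>& update, std::vector<T>& stateV,
                        T lr, T momentum) {
    assert(initState.size() == grad.size());
    const T momentumT = -momentum - 1;                 // same constant the kernel precomputes
    update.resize(grad.size());
    stateV.resize(grad.size());
    for (std::size_t i = 0; i < grad.size(); ++i) {
        const T prevState = momentum * initState[i];   // momentum * v_prev
        stateV[i] = prevState - lr * grad[i];          // v_new = momentum * v_prev - lr * grad
        update[i] = prevState + momentumT * stateV[i]; // momentum * v_prev - (1 + momentum) * v_new
    }
}
```

With a zero initial state the update reduces to `(1 + momentum) * lr * grad`, i.e. a plain SGD step scaled by `1 + momentum`.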

///////////////////////////////////////////////////////////////////
template<typename T>
linkage void nesterovsUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
                                          const void* vx, const Nd4jLong* xShapeInfo, const void* vin, const Nd4jLong* inShapeInfo,
                                          void* vz, const Nd4jLong* zShapeInfo, void* vst, const Nd4jLong* stShapeInfo,
                                          const double dLr, const double dMomentum) {

    // hyperparameters arrive as double and are cast to the array data type T
    const T lr = static_cast<T>(dLr);
    const T momentum = static_cast<T>(dMomentum);

    nesterovsUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vin, inShapeInfo,
                                                                              vz, zShapeInfo, vst, stShapeInfo, lr, momentum);
}
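When the arrays are not all contiguous with the same ordering, `nesterovsUpdaterCuda` falls back to `shape::index2coords` and `shape::getOffset`. The host sketch below illustrates that mapping, assuming a row-major traversal of the linear index; `indexToCoordsC` and `offsetFromCoords` are hypothetical stand-ins that take plain shape and stride vectors, not the libnd4j signatures, which operate on a shapeInfo buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical: convert a linear index into coordinates, last dimension varying fastest.
std::vector<long long> indexToCoordsC(long long index, const std::vector<long long>& shape) {
    std::vector<long long> coords(shape.size(), 0);
    for (std::size_t d = shape.size(); d-- > 0;) {
        coords[d] = index % shape[d];
        index /= shape[d];
    }
    return coords;
}

// Hypothetical: memory offset of an element given its coordinates and per-dimension strides.
long long offsetFromCoords(const std::vector<long long>& coords, const std::vector<long long>& strides) {
    assert(coords.size() == strides.size());
    long long offset = 0;
    for (std::size_t d = 0; d < coords.size(); ++d)
        offset += coords[d] * strides[d];
    return offset;
}
```

For a 2x3 array, linear index 4 maps to coordinates (1, 1): with c-order strides (3, 1) the offset equals the index, while f-order strides (1, 2) place the same element at offset 3. This is why the kernel can skip the conversion only when every array has element-wise stride 1 and the same ordering.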

///////////////////////////////////////////////////////////////////
void updaterNesterovs(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initState,
                      NDArray& update, NDArray& stateV, const double dLr, const double dMomentum) {

    PointersManager manager(context, "nesterovsUpdater");

    // one thread per element; round the block count up so every element is covered
    const int threadsPerBlock = MAX_NUM_THREADS / 4;
    const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

    // make sure inputs are up to date on the device and outputs have device memory
    NDArray::prepareSpecialUse({ &update, &stateV }, { &gradient, &initState });
    // dispatch to the launcher instantiation matching the gradient's floating-point type
    BUILD_SINGLE_SELECTOR(gradient.dataType(), nesterovsUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock,
                          context->getCudaStream(), gradient.specialBuffer(), gradient.specialShapeInfo(),
                          initState.specialBuffer(), initState.specialShapeInfo(),
                          update.specialBuffer(), update.specialShapeInfo(),
                          stateV.specialBuffer(), stateV.specialShapeInfo(), dLr, dMomentum), FLOAT_TYPES);
    // record that the outputs were written and the inputs read on the device
    NDArray::registerSpecialUse({ &update, &stateV }, { &gradient, &initState });

    manager.synchronize();
}
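The launch configuration above uses ceiling division, and the kernel's grid-stride loop would still visit every element exactly once even with fewer blocks. A host-only sketch of that arithmetic follows; `blocksFor` and `gridStrideVisits` are hypothetical names, and the threads-per-block value is a parameter rather than libnd4j's `MAX_NUM_THREADS`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the launch arithmetic in updaterNesterovs:
// round the block count up so that blocks * threadsPerBlock >= length.
long long blocksFor(long long length, int threadsPerBlock) {
    assert(threadsPerBlock > 0);
    return (length + threadsPerBlock - 1) / threadsPerBlock;
}

// Simulates the kernel's grid-stride loop on the host and counts
// how many times each linear index is processed.
std::vector<int> gridStrideVisits(long long length, long long blocks, int threadsPerBlock) {
    std::vector<int> visits(static_cast<std::size_t>(length), 0);
    const long long stride = blocks * threadsPerBlock;   // gridDim.x * blockDim.x
    for (long long block = 0; block < blocks; ++block)
        for (int thread = 0; thread < threadsPerBlock; ++thread)
            for (long long i = block * threadsPerBlock + thread; i < length; i += stride)
                ++visits[static_cast<std::size_t>(i)];
    return visits;
}
```

Because the loop strides by the total thread count, a grid smaller than `blocksFor(length, threads)` still covers all elements; the ceiling division simply gives each thread at most one element.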

}
}
}