* libnd4j raw implementation of sgd updater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j some corrections and simple test added (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j some corrections after discussion (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j integrate applyScalar (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j raw implementation of rmsPropUpdater on CPU (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fix operations declaration (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j rmsPropUpdater added, test cases for sgd, etc (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fixed several typos (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j some fixes and improvements for rmsPropUpdater based on Java tests (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fixed CUDA implementation, updated tests, and corrected behavior according to Java tests (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j adaGrad updater added (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j one minor fix for adaGrad (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j several more fixes for ada_grad (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j nesterovs updater added (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fixed nesterovs updater behavior, several typos, and renamed a file (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j one minor typo (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j adaMax updater added (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fixed several typos in adaMax updater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fixed several typos in adaMaxUpdater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j several fixes for adaMax, added Adam updater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j adaDeltaUpdater added, minor fixes for adamUpdater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j several fixes for adaDeltaUpdater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j nadamUpdater added (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j one more correction for nadam updater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j several fixes for nadam updater and added amsGradUpdater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j several typos fixed in amsGradUpdater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j some corrections and added f-order support to the rmsProp updater (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j added f-order support for all updaters and modified tests to cover in-place execution (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j fixed issues for updates when not-in-place mode is used, added tests for f order (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j added input shape checks (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j some corrections for handling different cases (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j some code cleanup and optimization per review request (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j updaters refactoring after review (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* SgdUpdater wrapper (Signed-off-by: raver119 <raver119@gmail.com>)
* first test (Signed-off-by: raver119 <raver119@gmail.com>)
* RmsPropUpdater added (Signed-off-by: raver119 <raver119@gmail.com>)
* NadamUpdater + NesterovsUpdater (Signed-off-by: raver119 <raver119@gmail.com>)
* AmsGradUpdater (Signed-off-by: raver119 <raver119@gmail.com>)
* AdamUpdater added (Signed-off-by: raver119 <raver119@gmail.com>)
* AdaGradUpdater + AdaDeltaUpdater + AdaMaxUpdater (Signed-off-by: raver119 <raver119@gmail.com>)
* AdaGradUpdater test added (Signed-off-by: raver119 <raver119@gmail.com>)
* libnd4j removed input-parameter parsing through NDArray, split helper implementations into separate files, some renames, etc (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j next step in splitting operation implementations into separate files (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j merged master and minor corrections (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j reverted some changes of the split implementation (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* libnd4j forgot to add header file (Signed-off-by: Oleg <oleg.semeniv@gmail.com>)
* public default constructors (Signed-off-by: raver119 <raver119@gmail.com>)
* ImportClassMapping updated (Signed-off-by: raver119 <raver119@gmail.com>)

Co-authored-by: raver119 <raver119@gmail.com>
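
For reference, the bias-corrected Adam step that the CUDA helper below implements, written out as math (a sketch that mirrors the kernel; g_t is a gradient element, m and v the first- and second-moment states, and t = iteration + 1):

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\alpha_t = \mathrm{lr} \cdot \sqrt{1 - \beta_2^t} \,/\, (1 - \beta_1^t)$$
$$u_t = \alpha_t\, m_t \,/\, (\sqrt{v_t} + \epsilon)$$

Both bias corrections are folded into the single step size \alpha_t (epsilonT in the kernel), with a fallback to plain \epsilon when that expression evaluates to 0, NaN, or Inf.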
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//

#include <system/op_boilerplate.h>
#include <ops/declarable/helpers/updatersHelpers.h>
#include <helpers/PointersManager.h>
#include <math/platformmath.h>
#include <math/templatemath.h>

namespace sd {
namespace ops {
namespace helpers {

///////////////////////////////////////////////////////////////////
template<typename T>
__global__ void adamUpdaterCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vinv, const Nd4jLong* invShapeInfo,
                                const void* vinm, const Nd4jLong* inmShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
                                void* vstV, const Nd4jLong* stvShapeInfo, void* vstM, const Nd4jLong* stmShapeInfo,
                                const T lr, const T beta1, const T beta2, const T epsilon, const T iteration) {

    // inputs: gradient, previous second-moment (U) and first-moment (M) states
    const auto grad  = reinterpret_cast<const T*>(vx);
    const auto initU = reinterpret_cast<const T*>(vinv);
    const auto initM = reinterpret_cast<const T*>(vinm);

    // outputs: the update and the advanced moment states
    auto up  = reinterpret_cast<T*>(vz);
    auto stU = reinterpret_cast<T*>(vstV);
    auto stM = reinterpret_cast<T*>(vstM);

    __shared__ Nd4jLong xLen;
    __shared__ T epsilonT;
    __shared__ bool bEWS, bOrdering, bXZsame, bXInUSame, bXStUSame, bXInMSame, bXStMSame;

    if (threadIdx.x == 0) {
        xLen = shape::length(xShapeInfo);

        // fold both Adam bias corrections into a single step size (epsilonT),
        // falling back to epsilon if the expression degenerates
        T beta1T = sd::math::nd4j_pow<T, T, T>(beta1, (iteration + 1));
        T beta2T = sd::math::nd4j_pow<T, T, T>(beta2, (iteration + 1));

        epsilonT = lr * sd::math::nd4j_sqrt<T, T>(1.0 - beta2T) / (1.0 - beta1T);
        if (sd::math::nd4j_isnan(epsilonT) || 0 == epsilonT || sd::math::nd4j_isinf(epsilonT))
            epsilonT = epsilon;

        // the linear fast path is valid only if all six buffers are contiguous and share one ordering
        bEWS = 1 == shape::elementWiseStride(xShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo) &&
               1 == shape::elementWiseStride(stmShapeInfo) && 1 == shape::elementWiseStride(inmShapeInfo) &&
               1 == shape::elementWiseStride(stvShapeInfo) && 1 == shape::elementWiseStride(invShapeInfo);
        bOrdering = shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(stmShapeInfo) &&
                    shape::order(stmShapeInfo) == shape::order(inmShapeInfo) && shape::order(inmShapeInfo) == shape::order(stvShapeInfo) &&
                    shape::order(stvShapeInfo) == shape::order(invShapeInfo);

        bXZsame   = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
        bXInUSame = shape::haveSameShapeAndStrides(xShapeInfo, invShapeInfo);
        bXStUSame = shape::haveSameShapeAndStrides(xShapeInfo, stvShapeInfo);
        bXInMSame = shape::haveSameShapeAndStrides(xShapeInfo, inmShapeInfo);
        bXStMSame = shape::haveSameShapeAndStrides(xShapeInfo, stmShapeInfo);
    }
    __syncthreads();

    int coords[MAX_RANK];

    // grid-stride loop over all elements of the gradient
    for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) {

        auto xOffset = i, zOffset = i, initMOffset = i, initUOffset = i, stMOffset = i, stUOffset = i;

        // slow path: recompute per-array offsets from coordinates
        if (!bEWS || !bOrdering) {
            shape::index2coords(i, xShapeInfo, coords);
            xOffset     = shape::getOffset(xShapeInfo, coords);
            zOffset     = bXZsame   ? xOffset : shape::getOffset(zShapeInfo, coords);
            initUOffset = bXInUSame ? xOffset : shape::getOffset(invShapeInfo, coords);
            stUOffset   = bXStUSame ? xOffset : shape::getOffset(stvShapeInfo, coords);
            initMOffset = bXInMSame ? xOffset : shape::getOffset(inmShapeInfo, coords);
            stMOffset   = bXStMSame ? xOffset : shape::getOffset(stmShapeInfo, coords);
        }

        // exponential moving averages of the gradient and the squared gradient
        stM[stMOffset] = beta1 * initM[initMOffset] + grad[xOffset] * (1 - beta1);
        stU[stUOffset] = beta2 * initU[initUOffset] + grad[xOffset] * grad[xOffset] * (1 - beta2);

        up[zOffset] = (stM[stMOffset] * epsilonT) / (sd::math::nd4j_sqrt<T, T>(stU[stUOffset]) + epsilon);
    }
}

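// Worked example for the folded step size above (illustrative only, not part of the build):
// with lr = 0.001, beta1 = 0.9, beta2 = 0.999 and the first iteration (t = 1):
//   epsilonT = 0.001 * sqrt(1 - 0.999) / (1 - 0.9)
//            = 0.001 * 0.0316228 / 0.1
//           ~= 3.16228e-4
// As t grows, both correction factors approach 1, so epsilonT approaches lr.
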
///////////////////////////////////////////////////////////////////
template<typename T>
linkage void adamUpdaterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
    const void* vx, const Nd4jLong* xShapeInfo,
    const void* vinv, const Nd4jLong* invShapeInfo, const void* vinm, const Nd4jLong* inmShapeInfo,
    void* vz, const Nd4jLong* zShapeInfo, void* vstV, const Nd4jLong* stvShapeInfo,
    void* vstM, const Nd4jLong* stmShapeInfo, const double dLr, const double dBeta1, const double dBeta2, const double dEpsilon, const int nIteration) {

    // narrow the double/int host-side parameters to the kernel's element type
    const T lr = static_cast<T>(dLr);
    const T beta1 = static_cast<T>(dBeta1);
    const T beta2 = static_cast<T>(dBeta2);
    const T epsilon = static_cast<T>(dEpsilon);
    const T iteration = static_cast<T>(nIteration);

    adamUpdaterCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vinv, invShapeInfo, vinm, inmShapeInfo,
        vz, zShapeInfo, vstV, stvShapeInfo, vstM, stmShapeInfo, lr, beta1, beta2, epsilon, iteration);
}

///////////////////////////////////////////////////////////////////
void updaterAdam(sd::LaunchContext* context, const NDArray& gradient, const NDArray& initStateU, const NDArray& initStateM,
                 NDArray& update, NDArray& stateU, NDArray& stateM, const double dLr, const double dBeta1, const double dBeta2,
                 const double dEpsilon, const int nIteration) {

    PointersManager manager(context, "adamUpdater");

    // one thread per gradient element, rounded up to whole blocks
    const int threadsPerBlock = MAX_NUM_THREADS / 4;
    const int blocksPerGrid = (gradient.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

    NDArray::prepareSpecialUse({&update, &stateU, &stateM}, {&gradient, &initStateU, &initStateM});

    // dispatch the templated launcher for the gradient's floating-point type
    BUILD_SINGLE_SELECTOR(gradient.dataType(), adamUpdaterCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), gradient.getSpecialBuffer(), gradient.getSpecialShapeInfo(),
        initStateU.getSpecialBuffer(), initStateU.getSpecialShapeInfo(), initStateM.getSpecialBuffer(), initStateM.getSpecialShapeInfo(),
        update.getSpecialBuffer(), update.getSpecialShapeInfo(), stateU.getSpecialBuffer(), stateU.getSpecialShapeInfo(),
        stateM.getSpecialBuffer(), stateM.getSpecialShapeInfo(), dLr, dBeta1, dBeta2, dEpsilon, nIteration), FLOAT_TYPES);

    NDArray::registerSpecialUse({&update, &stateU, &stateM}, {&gradient, &initStateU, &initStateM});

    manager.synchronize();
}

}  // namespace helpers
}  // namespace ops
}  // namespace sd
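
A minimal host-side sketch of how updaterAdam might be driven, matching the signature above. This is illustrative only, not part of the file: the include paths, NDArrayFactory::create overloads, and LaunchContext::defaultContext() are assumed to behave as in the mainline libnd4j tree of this era.

#include <array/NDArrayFactory.h>
#include <ops/declarable/helpers/updatersHelpers.h>

using namespace sd;

void runAdamStep() {
    // gradient plus zero-initialized first/second moment states
    auto grad  = NDArrayFactory::create<float>('c', {2, 2}, {0.1f, 0.2f, 0.3f, 0.4f});
    auto initU = NDArrayFactory::create<float>('c', {2, 2});   // v (second moment)
    auto initM = NDArrayFactory::create<float>('c', {2, 2});   // m (first moment)

    // outputs: the update and the advanced states, same shape as the gradient
    auto update = NDArrayFactory::create<float>('c', {2, 2});
    auto stateU = NDArrayFactory::create<float>('c', {2, 2});
    auto stateM = NDArrayFactory::create<float>('c', {2, 2});

    // lr, beta1, beta2, epsilon, iteration (0-based; the kernel applies t = iteration + 1)
    ops::helpers::updaterAdam(LaunchContext::defaultContext(), grad, initU, initM,
                              update, stateU, stateM, 0.001, 0.9, 0.999, 1e-8, 0);
}

Per the commit history above, the updaters were tested both in place and not in place; the haveSameShapeAndStrides checks in the kernel let all six buffers share a single offset whenever their shapes and strides match, which is what makes the in-place case cheap.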