/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com), created on 19.04.2018 // @author raver119@gmail.com // #include #include #include #include #include namespace sd { namespace ops { namespace helpers { /////////////////////////////////////////////////////////////////// template void static _softMaxDerivForVector(sd::LaunchContext * context, const void *input, const Nd4jLong *inShapeInfo, void *output) { const T* inBuff = reinterpret_cast(input); T* outBuff = reinterpret_cast(output); T max = -DataTypeUtils::max(); T sum = 0.; int length = shape::length(inShapeInfo); for (int i = 0; i < length; i++) { const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); max = sd::math::nd4j_max(max, inBuff[offset]); } for (int i = 0; i < length; i++) { const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); outBuff[offset] = sd::math::nd4j_exp(inBuff[offset] - max); sum += outBuff[offset]; } for (int i = 0; i < length; i++) { const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo); outBuff[offset] /= sum; outBuff[offset] *= (1.f - outBuff[offset]); // derivative } } /////////////////////////////////////////////////////////////////// void softmaxDerivative(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { const int rank = input.rankOf(); int temp; if(shape::isCommonVector(input.getShapeInfo(), temp)) { BUILD_SINGLE_SELECTOR(input.dataType(), _softMaxDerivForVector, (context, input.getBuffer(), input.getShapeInfo(), output.buffer()), FLOAT_TYPES); } else { auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; output *= (1.f - output); // derivative } } /////////////////////////////////////////////////////////////////// template void logSoftMaxForVector_(void *input, Nd4jLong *inShapeInfo, void *output, Nd4jLong *outShapeInfo) { auto inBuff = reinterpret_cast(input); auto outBuff = reinterpret_cast(output); T max = -DataTypeUtils::max(); T sum = 0; auto inEWS = shape::elementWiseStride(inShapeInfo); auto length = shape::length(inShapeInfo); if (inEWS == 1) { for (Nd4jLong i = 0; i < length; i++) max = sd::math::nd4j_max(max, inBuff[i]); PRAGMA_OMP_SIMD_SUM(sum) for (Nd4jLong i = 0; i < length; i++) { outBuff[i] = sd::math::nd4j_exp(inBuff[i] - max); sum += outBuff[i]; } PRAGMA_OMP_SIMD for (Nd4jLong i = 0; i < length; i++) { outBuff[i] /= sum; outBuff[i] = sd::math::nd4j_log(outBuff[i]); } } else if (inEWS > 1) { PRAGMA_OMP_SIMD_MAX(max) for (Nd4jLong i = 0; i < length; i++) max = sd::math::nd4j_max(max, inBuff[i * inEWS]); PRAGMA_OMP_SIMD_SUM(sum) for (Nd4jLong i = 0; i < length; i++) { outBuff[i * inEWS] = sd::math::nd4j_exp(inBuff[i * inEWS] - max); sum += outBuff[i * inEWS]; } PRAGMA_OMP_SIMD for (Nd4jLong i = 0; i < length; i++) { outBuff[i * inEWS] /= sum; outBuff[i * inEWS] = sd::math::nd4j_log(outBuff[i * inEWS]); } } } /////////////////////////////////////////////////////////////////// void logSoftMaxForVector(sd::LaunchContext * context, const NDArray& input, NDArray& output) { if(!input.isVector() || !output.isVector()) throw std::runtime_error("ops::helpers::logSoftMaxForVector function input and output arrays must be vectors !"); auto xType = input.dataType(); BUILD_SINGLE_SELECTOR(xType, logSoftMaxForVector_, (input.getBuffer(), input.getShapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// void prelu(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { const Nd4jLong inputLen = input.lengthOf(); const Nd4jLong* inputShapeInfo = input.getShapeInfo(); const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo(); auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i++) { // FIXME: double! double x = input.e(i); if (x < 0.0) { // FIXME: double output.p(i, (x * alpha.e(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo)))); } else output.p(i, x); } }; sd::Threads::parallel_for(func, 0, inputLen); } ////////////////////////////////////////////////////////////////////////// void preluBP(sd::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) { const Nd4jLong inputLen = input.lengthOf(); const Nd4jLong* inputShapeInfo = input.getShapeInfo(); const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo(); dLdA.assign(0.0f); for(Nd4jLong i = 0; i < inputLen; ++i) { // FIXME: double double x = input.e(i); double grO = dLdO.e(i); if(x < 0.0) { Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo); dLdI.p(i, grO * alpha.e(alphaInd)); double prevVal = dLdA.e(alphaInd); prevVal += (grO * x); dLdA.p(alphaInd, prevVal); } else dLdI.p(i, grO); } } bool checkAlphaShapeLen(std::vector const& expectedShape, Nd4jLong shapeLen) { Nd4jLong expectedAlphaLen = std::accumulate(expectedShape.cbegin(), expectedShape.cend(), 1, std::multiplies()); return expectedAlphaLen == shapeLen; } template static void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) { auto routine = LAMBDA_T(_x, threshold) { return _x > (T)threshold? _x: (T)0.f; }; const_cast(input).applyLambda(routine, output); } void thresholdRelu(sd::LaunchContext * context, NDArray const& input, double threshold, NDArray& output) { BUILD_SINGLE_SELECTOR(input.dataType(), thresholdRelu_, (input, threshold, output), FLOAT_TYPES); } template static void thresholdReluDerivative_(sd::LaunchContext * context, NDArray* input, double theta, NDArray* dLdO, NDArray* output) { auto derivative = LAMBDA_TT(_x, grO, theta) {if (_x > theta) return grO; else return static_cast(0); }; input->applyPairwiseLambda(*dLdO, derivative, *output); } void thresholdReluDerivative(sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (context, input, threshold, dLdO, output), FLOAT_TYPES); } /////////////////////////////////////////////////////////////////// void logSoftmax(sd::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { const int rank = input.rankOf(); if(input.isVector()) { if(rank == 1 || input.sizeAt(dimension) != 1) { BUILD_SINGLE_SELECTOR(input.dataType(), logSoftMaxForVector_, (input.getBuffer(), input.getShapeInfo(), output.buffer(), output.shapeInfo()), FLOAT_TYPES); } else output = 0.; } else { auto maxAlongDim = const_cast(input).reduceAlongDimension(reduce::Max, {dimension}, true); (input - maxAlongDim).applyTransform(transform::Exp, output); // output contains exponents temporarily auto sumAlongDim = output.reduceAlongDimension(reduce::Sum, {dimension}, true); output /= sumAlongDim; output.applyTransform(transform::Log, output); } } BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (sd::LaunchContext * context, NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void logSoftMaxForVector_, (void *input, Nd4jLong *inShapeInfo, void *output, Nd4jLong *outShapeInfo), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void _softMaxDerivForVector, (sd::LaunchContext * context, const void *input, const Nd4jLong *inShapeInfo, void *output), FLOAT_TYPES); } } }