cavis/libnd4j/include/ops/declarable/generic/nn/layer_norm.cpp


/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
//
// @author Paul Dubs
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_layer_norm)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/reverse.h>
#include <ops/declarable/helpers/addBias.h>
namespace sd {
namespace ops {
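// layer_norm: standardizes the input over the dimensions given as integer arguments,
// then multiplies by the 1D gain array and, when a third input is supplied, adds a 1D bias.
// Gain and bias are applied along the channel dimension: dimension 1 for NCHW (default),
// or the last dimension for NHWC.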
CONFIGURABLE_OP_IMPL(layer_norm, 2, 1, false, 0, -1) {
  auto input = INPUT_VARIABLE(0);
  auto gain = INPUT_VARIABLE(1);
  auto output = OUTPUT_VARIABLE(0);

  // integer arguments define the dimensions along which standardization is performed
  std::vector<int> axis = *block.getIArguments();

  const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;  // B_ARG(0): true -> NCHW (channels at dimension 1), false -> NHWC (channels last)
  const int dimC = isNCHW ? 1 : input->rankOf() - 1;

  REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());

  NDArray* bias = nullptr;
  if (block.width() > 2) {
    bias = INPUT_VARIABLE(2);
    REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
  }

  std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);

  // standardize the input (zero mean, unit variance) along the requested axes, writing into output
  sd::ops::standardize standardizeOp;
  std::vector<NDArray *> inputs = {input};
  std::vector<NDArray *> outputs = {output};
  std::vector<double> targs = {};
  std::vector<bool> bargs = {};
  standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);

  // scale by gain along the channel dimension and, if present, add bias
  // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, output);
  output->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *output);
  if (bias != nullptr) {
    // output->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), bias, output);
    // output->applyBroadcast(sd::broadcast::Add, {dimC}, bias);
    helpers::addBias(block, *output, *bias, *output, isNCHW);
  }

  return Status::OK();
}
DECLARE_TYPES(layer_norm) {
  getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS});
  getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS});
}
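
// layer_norm_bp: backpropagation for layer_norm. Inputs are input, gain, optional bias,
// and eps (the gradient with respect to the op output); outputs are dLdx, dLdg and,
// when bias is given, dLdb.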
CUSTOM_OP_IMPL(layer_norm_bp, 3, -1, false, 0, -1) {
  auto input = INPUT_VARIABLE(0);
  auto gain = INPUT_VARIABLE(1);
  auto bias = block.width() == 4 ? INPUT_VARIABLE(2) : nullptr;
  auto eps = block.width() == 4 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);
  auto dLdx = OUTPUT_VARIABLE(0);
  auto dLdg = OUTPUT_VARIABLE(1);
  auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr;

  const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;  // B_ARG(0): true -> NCHW (channels at dimension 1), false -> NHWC (channels last)
  const int dimC = isNCHW ? 1 : input->rankOf() - 1;

  REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());

  std::vector<int> axis = *block.getIArguments();
  std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);

  if (bias != nullptr) {
    REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
    // bias gradient: sum of eps over all dimensions except the channel dimension
    // eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, {0}, true);
    eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));
  }

  // gain gradient: sum of standardize(input) * eps over all dimensions except the channel dimension
  NDArray standardized(input->shapeInfo(), false, block.launchContext());
  sd::ops::standardize standardizeOp;
  std::vector<NDArray *> inputs = {input};
  std::vector<NDArray *> outputs = {&standardized};
  std::vector<double> targs = {};
  std::vector<bool> bargs = {};
  standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
  standardized.applyPairwiseTransform(sd::pairwise::Multiply, *eps, standardized);
  standardized.reduceAlongDimension(sd::reduce::Sum, *dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));

  // input gradient: scale eps by gain along the channel dimension, then backpropagate through the standardization
  sd::ops::standardize_bp standardizeBp;
  // eps->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), gain, dLdx);
  eps->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *dLdx);

  auto dLdx_tmp = dLdx->dup();
  std::vector<NDArray *> standardizeBpArgs = {input, &dLdx_tmp};
  std::vector<NDArray *> standardizeBpOut = {dLdx};
  standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, bargs);

  return Status::OK();
}
DECLARE_TYPES(layer_norm_bp) {
  getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS});
  getOpDescriptor()->setAllowedOutputTypes({ALL_FLOATS});
}
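
// each gradient output has the same shape as the corresponding input (input, gain and, if present, bias)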
DECLARE_SHAPE_FN(layer_norm_bp) {
  Nd4jLong *dLdx_shape;
  COPY_SHAPE(inputShape->at(0), dLdx_shape);

  Nd4jLong *dLdg_shape;
  COPY_SHAPE(inputShape->at(1), dLdg_shape);

  if (inputShape->size() > 3) {
    Nd4jLong *dLdb_shape;
    COPY_SHAPE(inputShape->at(2), dLdb_shape);
    return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape), CONSTANT(dLdb_shape));
  }

  return SHAPELIST(CONSTANT(dLdx_shape), CONSTANT(dLdg_shape));
}
}
}
#endif