2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author Paul Dubs
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_layer_norm)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/reverse.h>
2019-09-11 19:12:09 +02:00
# include <ops/declarable/helpers/addBias.h>
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////////
// layer_norm: standardizes `input` along the axes given as integer arguments
// (via the `standardize` op), scales the result by `gain` per channel and,
// when a third input is supplied, shifts it by `bias` per channel.
//
// Inputs:
//   0 - input
//   1 - gain: rank-1, length must equal input->sizeAt(dimC)
//   2 - bias (optional): rank-1, length must equal input->sizeAt(dimC)
// Integer args: the axes forwarded to `standardize`.
// Boolean arg 0 (optional, default true): true  -> channel dimension is 1 (NCHW),
//                                         false -> channel dimension is the last one (NHWC).
CONFIGURABLE_OP_IMPL(layer_norm, 2, 1, false, 0, -1) {
    auto input  = INPUT_VARIABLE(0);
    auto gain   = INPUT_VARIABLE(1);
    auto output = OUTPUT_VARIABLE(0);

    std::vector<int> axis = *block.getIArguments();

    // B_ARG(0): true (default) -> NCHW, channels at dim 1; false -> NHWC, channels at the last dim.
    const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;
    const int dimC = isNCHW ? 1 : input->rankOf() - 1;

    // sizeAt() yields a 64-bit Nd4jLong; it is formatted with %lld and an explicit cast
    // because handing a 64-bit value to %i in a varargs format call is undefined behavior.
    REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%lld}, but got %s instead !", static_cast<long long>(input->sizeAt(dimC)), ShapeUtils::shapeAsString(gain).c_str());

    NDArray* bias = nullptr;
    if (block.width() > 2) {
        bias = INPUT_VARIABLE(2);
        REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%lld}, but got %s instead !", static_cast<long long>(input->sizeAt(dimC)), ShapeUtils::shapeAsString(bias).c_str());
    }

    std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);

    // output = standardize(input) along the requested axes
    sd::ops::standardize standardizeOp;
    std::vector<NDArray*> inputs = {input};
    std::vector<NDArray*> outputs = {output};
    std::vector<double> targs = {};
    std::vector<bool> bargs = {};
    const auto status = standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
    if (status != Status::OK())
        return status;   // propagate the failure instead of post-processing an invalid result

    // output *= gain, broadcast along the channel dimension
    output->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *output);

    // output += bias, with the channel position selected by isNCHW
    if (bias != nullptr)
        helpers::addBias(block, *output, *bias, *output, isNCHW);

    return Status::OK();
}
DECLARE_TYPES(layer_norm) {
    // layer_norm accepts and produces floating-point arrays only.
    getOpDescriptor()
        ->setAllowedInputTypes({ALL_FLOATS})
        ->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
// layer_norm_bp: backpropagation for layer_norm.
//
// Inputs with bias:    0 - input, 1 - gain, 2 - bias, 3 - eps
// Inputs without bias: 0 - input, 1 - gain, 2 - eps
//   eps is used as the incoming gradient of layer_norm's output.
// Outputs: 0 - dLdx, 1 - dLdg and, only when a bias was supplied, 2 - dLdb.
// Integer args: the axes forwarded to `standardize` / `standardize_bp`.
// Boolean arg 0 (optional, default true): true  -> channel dimension is 1 (NCHW),
//                                         false -> channel dimension is the last one (NHWC).
CUSTOM_OP_IMPL(layer_norm_bp, 3, -1, false, 0, -1) {
    // A bias is present only when exactly 4 inputs are given; eps then moves to index 3.
    const bool hasBias = block.width() == 4;

    auto input = INPUT_VARIABLE(0);
    auto gain  = INPUT_VARIABLE(1);
    auto bias  = hasBias ? INPUT_VARIABLE(2) : nullptr;
    auto eps   = hasBias ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);
    auto dLdx  = OUTPUT_VARIABLE(0);
    auto dLdg  = OUTPUT_VARIABLE(1);
    auto dLdb  = hasBias ? OUTPUT_VARIABLE(2) : nullptr;

    // B_ARG(0): true (default) -> NCHW, channels at dim 1; false -> NHWC, channels at the last dim.
    const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;
    const int dimC = isNCHW ? 1 : input->rankOf() - 1;

    // sizeAt() yields a 64-bit Nd4jLong; %lld plus an explicit cast avoids the
    // undefined behavior of passing it to %i in a varargs format call.
    REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%lld}, but got %s instead !", static_cast<long long>(input->sizeAt(dimC)), ShapeUtils::shapeAsString(gain).c_str());

    std::vector<int> axis = *block.getIArguments();
    std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);

    // Per-channel gradients (dLdb, dLdg) are sums over every dimension except dimC.
    const auto dimsToReduce = ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC});

    if (bias != nullptr) {
        REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%lld}, but got %s instead !", static_cast<long long>(input->sizeAt(dimC)), ShapeUtils::shapeAsString(bias).c_str());
        // dLdb = sum(eps) over the non-channel dimensions
        eps->reduceAlongDimension(sd::reduce::Sum, *dLdb, dimsToReduce);
    }

    // dLdg = sum(standardize(input) * eps) over the non-channel dimensions
    NDArray standardized(input->shapeInfo(), false, block.launchContext());

    sd::ops::standardize standardizeOp;
    std::vector<NDArray*> inputs = {input};
    std::vector<NDArray*> outputs = {&standardized};
    std::vector<double> targs = {};
    std::vector<bool> bargs = {};
    const auto status = standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
    if (status != Status::OK())
        return status;   // do not build gradients from a failed standardization

    standardized.applyPairwiseTransform(sd::pairwise::Multiply, *eps, standardized);
    standardized.reduceAlongDimension(sd::reduce::Sum, *dLdg, dimsToReduce);

    // dLdx: scale eps by gain per channel, then back-propagate through standardize.
    eps->applyBroadcast(sd::broadcast::Multiply, {dimC}, *gain, *dLdx);

    // dLdx now holds eps*gain; a copy of it is fed to standardize_bp as the gradient
    // input so that the op's input and output are different arrays.
    auto dLdx_tmp = dLdx->dup();
    sd::ops::standardize_bp standardizeBp;
    std::vector<NDArray*> standardizeBpArgs = {input, &dLdx_tmp};
    std::vector<NDArray*> standardizeBpOut = {dLdx};
    return standardizeBp.execute(standardizeBpArgs, standardizeBpOut, targs, longAxis, bargs);
}
DECLARE_TYPES(layer_norm_bp) {
    // layer_norm_bp accepts and produces floating-point arrays only.
    getOpDescriptor()
        ->setAllowedInputTypes({ALL_FLOATS})
        ->setAllowedOutputTypes({ALL_FLOATS});
}
// Output shapes for layer_norm_bp: every gradient copies the shape of the input it
// corresponds to. dLdx <- input 0, dLdg <- input 1; when more than 3 inputs are
// given, a third output dLdb <- input 2 is produced as well.
DECLARE_SHAPE_FN(layer_norm_bp) {
    Nd4jLong* inputGradShape;
    COPY_SHAPE(inputShape->at(0), inputGradShape);

    Nd4jLong* gainGradShape;
    COPY_SHAPE(inputShape->at(1), gainGradShape);

    if (inputShape->size() <= 3) {
        return SHAPELIST(CONSTANT(inputGradShape), CONSTANT(gainGradShape));
    }

    Nd4jLong* biasGradShape;
    COPY_SHAPE(inputShape->at(2), biasGradShape);
    return SHAPELIST(CONSTANT(inputGradShape), CONSTANT(gainGradShape), CONSTANT(biasGradShape));
}
}
}
# endif