/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
//
// @author raver119@gmail.com, created on 29/10/17.
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_batchnorm)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/batchnorm.h>

namespace sd {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
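    // CUSTOM_OP_IMPL(name, nIn, nOut, inplaceable, nTArgs, nIArgs):
    // 3 mandatory inputs (input, mean, variance), 1 output, not inplace,
    // 1 floating-point arg (epsilon) and 2 declared integer args (applyScale, applyOffset);
    // as the parsing below shows, any further integer args are read as normalization axes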
    auto input = INPUT_VARIABLE(0);
    auto mean = INPUT_VARIABLE(1);
    auto variance = INPUT_VARIABLE(2);

    NDArray* gamma = nullptr;
    NDArray* beta = nullptr;

    auto output = OUTPUT_VARIABLE(0);

    const bool applyScale = (bool)INT_ARG(0);
    const bool applyOffset = (bool)INT_ARG(1);
    const double epsilon = T_ARG(0);

    if (applyScale)
        gamma = INPUT_VARIABLE(3);
    if (applyOffset)
        beta = INPUT_VARIABLE(3 + (int)applyScale);
    const int numOfIntArgs = block.getIArguments()->size();
    const int inRank = input->rankOf();

    // get axes args to normalize input array over
    std::vector<int> axes;
    if (numOfIntArgs > 2)
        for (int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(inRank - 1); // default dimension to reduce along is last dimension
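    // e.g. for a rank-4 NHWC input the default normalizes over the channel axis 3;
    // for NCHW layouts the channel axis 1 must be passed explicitly as the third integer arg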
    const uint numOfAxes = axes.size();
    REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: number of axes to normalize over should be less than or equal to the rank of the input array, but got %i and %i correspondingly!", numOfAxes, inRank);
    // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
    // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
    std::vector<Nd4jLong> expShape;
    if (numOfAxes == 1)
        expShape.push_back(input->sizeAt(axes[0]));
    else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
        expShape = std::vector<Nd4jLong>(inRank, 1);
        for (uint i = 0; i < numOfAxes; ++i)
            expShape[axes[i]] = input->sizeAt(axes[i]);
    }
    REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
    REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
    if (gamma)
        REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
    if (beta)
        REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());

    // types of all input arrays should be the same
    for (unsigned long i = 1; i < block.width(); ++i)
        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same!");
    nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0);

    // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
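    // worked example (illustrative numbers): input = 3, mean = 1, variance = 4, epsilon = 0, gamma = 2, beta = 5
    // => output = 2 * ((3 - 1) / sqrt(4 + 0)) + 5 = 2 * 1 + 5 = 7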
    // auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
    // auto m = input->reduceAlongDimension(sd::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));

    helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
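    // helpers::batchnorm dispatches to the platform-specific implementation declared
    // in ops/declarable/helpers/batchnorm.h (included above)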
    // NDArray stdInv = *v + epsilon;
    // stdInv.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
    // stdInv.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
    // if(applyScale)
    //     stdInv *= *gamma;
    // // empty array with same shape as input
    // input->applyBroadcast(sd::broadcast::Subtract, axes, m, output);
    // output->applyBroadcast(sd::broadcast::Multiply, axes, &stdInv);
    // if(applyOffset)
    //     output->applyBroadcast(sd::broadcast::Add, axes, beta);
    // delete v;
    // delete m;

    return Status::OK();
}
DECLARE_TYPES(batchnorm) {
    getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true);
}
DECLARE_SHAPE_FN(batchnorm) {
    auto inShapeInfo = inputShape->at(0);
    DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));

    auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape
    return SHAPELIST(CONSTANT(outShapeInfo));
}
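
// Usage sketch (illustrative only; assumes the DeclarableOp::evaluate test helper used
// throughout this codebase's test suite, and made-up shapes/values):
//
//     NDArray input('c', {2, 4}, sd::DataType::FLOAT32);
//     NDArray mean('c', {4}, {1.f, 1.f, 1.f, 1.f}, sd::DataType::FLOAT32);
//     NDArray variance('c', {4}, {0.5f, 0.5f, 0.5f, 0.5f}, sd::DataType::FLOAT32);
//
//     sd::ops::batchnorm op;
//     // I args: applyScale = 0, applyOffset = 0 (no gamma/beta inputs); T arg: epsilon = 1e-5
//     auto results = op.evaluate({&input, &mean, &variance}, {1e-5}, {0, 0});
//     auto output = results.at(0); // same shape and (floating) type as input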
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
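    // 4 mandatory inputs (input, mean, variance, dLdO) and 3 mandatory outputs (dLdI, dLdM, dLdV);
    // gamma/beta and dLdG/dLdB appear in addition when applyScale/applyOffset are set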
    NDArray* input = INPUT_VARIABLE(0);
    NDArray* mean = INPUT_VARIABLE(1);
    NDArray* variance = INPUT_VARIABLE(2);
    NDArray* gamma = nullptr;
    NDArray* beta = nullptr;
    NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // upstream gradient ("epsilon" from the next layer)
    NDArray* dLdI = OUTPUT_VARIABLE(0);
    NDArray* dLdM = OUTPUT_VARIABLE(1);
    NDArray* dLdV = OUTPUT_VARIABLE(2);
    NDArray* dLdG = nullptr;
    NDArray* dLdB = nullptr;

    const bool applyScale = (bool)INT_ARG(0);
    const bool applyOffset = (bool)INT_ARG(1);
    const float epsilon = T_ARG(0);

    if (applyScale) {
        gamma = INPUT_VARIABLE(3);
        dLdG = OUTPUT_VARIABLE(3);
    }
    if (applyOffset) {
        beta = INPUT_VARIABLE(3 + (int)applyScale);
        dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
    }
    const int numOfIntArgs = block.getIArguments()->size();
    const int inRank = input->rankOf();

    // get axes args to normalize input array over
    std::vector<int> axes;
    if (numOfIntArgs > 2)
        for (int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(inRank - 1); // default dimension to reduce along is last dimension

    const uint numOfAxes = axes.size();
    REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: number of axes to normalize over should be less than or equal to the rank of the input array, but got %i and %i correspondingly!", numOfAxes, inRank);
    // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
    // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
    std::vector<Nd4jLong> expShape;
    if (numOfAxes == 1)
        expShape.push_back(input->sizeAt(axes[0]));
    else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
        expShape = std::vector<Nd4jLong>(inRank, 1);
        for (uint i = 0; i < numOfAxes; ++i)
            expShape[axes[i]] = input->sizeAt(axes[i]);
    }

    REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
    REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
    if (gamma)
        REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
    if (beta)
        REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());

    REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());

    // types of all input arrays should be the same (except dLdO)
    // dLdO is the last input, so stop the check right before it
    for (unsigned long i = 1; i < block.width() - 1; ++i)
        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same!");
    // ***** calculations ***** //

    // notations:
    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output, g = dLdO
    // stdInv = 1 / (v + eps)^0.5
    // N - batch size (product of spatial dimensions)

    // derivatives:
    // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)

    // dfdx = gamma*stdInv*g;
    // dfdm = -gamma*stdInv*g_sum;
    // dmdx = 1/N;
    // dvdx = 2 * (x - m) / N
    // dvdm = -2 * [(x - m)]_sum / N
    // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience

    // finally:
    // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) )
    // dLdG = (g * (x - m))_sum * stdInv
    // dLdB = g_sum
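    // scalar sanity check (illustrative): N = 2, x = {0, 2}, m = 1, v = 1, eps = 0, gamma = 1, g = {1, 1}
    // => stdInv = 1, g_sum = 2, g - g_sum/N = {0, 0}, dfdv = -0.5 * ((-1) + 1) * 1 = 0
    // => dLdI = {0, 0} (a constant upstream gradient yields zero input gradient),
    //    dLdG = 0, dLdB = g_sum = 2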
    // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
    // mean = input->reduceAlongDimension(sd::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));

    const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes);
    const bool keepUnitiesInShape = inRank == mean->rankOf();
    // inverse batch size 1/N: shape::tadLength over axes gives the number of statistics,
    // so tadLength / lengthOf = 1 / (number of elements each statistic is computed over)
    const float Ninv = 1.f * shape::tadLength(input->shapeInfo(), axes.data(), axes.size()) / input->lengthOf();
    // input - mean
    NDArray xMinusMean(input); // empty array with same shape as input
    input->applyBroadcast(sd::broadcast::Subtract, axes, *mean, xMinusMean);
    // stdInv
    NDArray stdInv = *variance + epsilon;
    stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon)
    stdInv.applyTransform(transform::Sqrt, stdInv);       // 1 / (variance + epsilon)^0.5

    // dvdm (use dLdM as storage for dvdm)
    xMinusMean.reduceAlongDimension(sd::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape);
    *dLdM *= -Ninv;

    // g_sum
    auto gSum = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes, keepUnitiesInShape);
    // dLdB
    if (applyOffset)
        dLdB->assign(gSum);

    // stdInv * (g - g_sum/N) (use dLdI as storage for this expression)
    gSum *= Ninv;
    dLdO->applyBroadcast(sd::broadcast::Subtract, axes, gSum, *dLdI);
    dLdI->applyBroadcast(sd::broadcast::Multiply, axes, stdInv, *dLdI);
    // dLdV <- [g*(x - m)]_sum
    (xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape);

    // dLdG
    *dLdV *= stdInv;
    if (applyScale)
        dLdG->assign(dLdV);
2019-06-06 14:21:15 +02:00
2019-11-13 15:15:18 +01:00
// (2 / N) * dfdv (use dLdV as storage for dfdv)
* dLdV * = stdInv * stdInv ; // dLdV*stdInv * stdInv^2
* dLdV * = - Ninv ; // -0.5f * (2 / N);
// dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression)
2020-03-02 10:49:41 +01:00
xMinusMean . applyBroadcast ( sd : : broadcast : : Add , axes , * dLdM , xMinusMean ) ;
xMinusMean . applyBroadcast ( sd : : broadcast : : Multiply , axes , * dLdV , xMinusMean ) ;
2019-11-13 15:15:18 +01:00
// dLdI
* dLdI + = xMinusMean ;
if ( applyScale )
2020-03-02 10:49:41 +01:00
dLdI - > applyBroadcast ( sd : : broadcast : : Multiply , axes , * gamma , * dLdI ) ;
2019-11-13 15:15:18 +01:00
* dLdM = 0 ; // put zeros so far
2019-10-26 13:14:21 +02:00
* dLdV = 0 ; // put zeros so far
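    // note: dLdM and dLdV are deliberately zeroed here; mean and variance arrive as
    // externally supplied statistics, and proper gradients for them appear to be
    // deferred (hence the "so far" comments above)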
    // java code
    // NDArray std = *variance + epsilon;
    // std.applyTransform(transform::Reciprocal);  // 1 / (variance + epsilon)
    // std.applyTransform(transform::Sqrt);        // 1 / (variance + epsilon)^0.5
    // NDArray xMu(input);
    // input->applyBroadcast(sd::broadcast::Subtract, axes, mean, &xMu);
    // NDArray xHat(input);
    // xMu.applyBroadcast(sd::broadcast::Multiply, axes, &std, &xHat);
    // NDArray dxhat(input);
    // dLdO->applyBroadcast(sd::broadcast::Multiply, axes, gamma, &dxhat);
    // NDArray temp = dxhat*xMu;
    // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
    // *dLdV *= -0.5f * std*std*std;
    // NDArray* dxmu1 = dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
    // *dxmu1 *= -std;
    // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
    // *dxmu2 *= *dLdV * (-2.f/N);
    // NDArray dLdmu = *dxmu1 + *dxmu2;
    // dLdmu *= (1.f/N);
    // *dLdV *= (2.f/N);
    // dxhat.applyBroadcast(sd::broadcast::Multiply, axes, &std);
    // xMu.applyBroadcast(sd::broadcast::Multiply, axes, dLdV);
    // dxhat += xMu;
    // dxhat.applyBroadcast(sd::broadcast::Add, axes, &dLdmu, dLdI);
    // delete dxmu1;
    // delete dxmu2;
    // xHat *= *dLdO;
    // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, keepUnitiesInShape);
    return Status::OK();
}
DECLARE_TYPES(batchnorm_bp) {
    getOpDescriptor()
            ->setAllowedInputTypes(0, sd::DataType::ANY)
            ->setAllowedInputTypes(1, sd::DataType::ANY)
            ->setAllowedInputTypes(2, sd::DataType::ANY)
            ->setAllowedInputTypes(3, {ALL_FLOATS})
            ->setAllowedInputTypes(4, sd::DataType::ANY)
            ->setAllowedInputTypes(5, sd::DataType::ANY)
            ->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(batchnorm_bp) {
    Nd4jLong const* inShapeInfo = inputShape->at(0);
    Nd4jLong const* meanShapeInfo = inputShape->at(1);

    const bool applyScale = (bool)INT_ARG(0);
    const bool applyOffset = (bool)INT_ARG(1);

    DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));

    auto shapes = SHAPELIST();

    // dLdI shapeInfo
    shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(outType, inShapeInfo));

    // dLdM shapeInfo
    shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(outType, meanShapeInfo));

    // dLdV shapeInfo (same as dLdM)
    shapes->push_back(shapes->at(shapes->size() - 1));

    // dLdG shapeInfo (same as dLdM)
    if (applyScale)
        shapes->push_back(shapes->at(shapes->size() - 1));

    // dLdB shapeInfo (same as dLdM)
    if (applyOffset)
        shapes->push_back(shapes->at(shapes->size() - 1));

    return shapes;
}
} // namespace ops
} // namespace sd

#endif