2019-06-06 14:21:15 +02:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_biasadd)
# include <ops/declarable/CustomOperations.h>
2019-09-11 19:12:09 +02:00
# include <ops/declarable/helpers/addBias.h>
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-09-11 19:12:09 +02:00
namespace ops {
////////////////////////////////////////////////////////////////////
/**
 * biasadd: adds a rank-1 bias array to the input along its channel dimension.
 *
 * Inputs : 0 - input array
 *          1 - bias array; must have rank 1 and length equal to the input's channel size
 * Output : 0 - result array; must have the same shape as the input
 * B args : 0 - optional flag; when true the channel dimension is 1 (NCHW layout),
 *              when false or absent the last dimension is used
 */
CUSTOM_OP_IMPL(biasadd, 2, 1, true, 0, 0) {
    auto input  = INPUT_VARIABLE(0);
    auto bias   = INPUT_VARIABLE(1);
    auto output = OUTPUT_VARIABLE(0);

    // Layout flag is optional; a missing boolean argument means "channels last".
    const bool isNCHW = block.getBArguments()->empty() ? false : B_ARG(0);

    // Channel dimension is either the second one (NCHW) or the last one.
    const int lastDim    = input->rankOf() - 1;
    const int channelDim = isNCHW ? 1 : lastDim;

    const int biasRank = bias->rankOf();
    REQUIRE_TRUE(biasRank == 1, 0,
                 " BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i instead ! ",
                 biasRank);

    REQUIRE_TRUE(bias->sizeAt(0) == input->sizeAt(channelDim), 0,
                 " BIASADD CUSTOM_OP: shapes of bias %s and input %s arrays are not suitable for broadcast operation along channel dimension %i ! ",
                 ShapeUtils::shapeAsString(bias).c_str(),
                 ShapeUtils::shapeAsString(input).c_str(),
                 channelDim);

    REQUIRE_TRUE(output->isSameShape(input), 0,
                 " BIASADD CUSTOM_OP: wrong shape of output array, expected is %s but got %s instead ! ",
                 ShapeUtils::shapeAsString(input).c_str(),
                 ShapeUtils::shapeAsString(output).c_str());

    // The dedicated helper performs the addition; the generic broadcast form is kept for reference:
    // input->applyBroadcast(sd::broadcast::Add, {channelDim}, bias, output);
    helpers::addBias(block, *input, *bias, *output, isNCHW);

    return Status::OK();
}
// Declares `bias_add` as a synonym of the `biasadd` op (alias semantics presumed from the DECLARE_SYN macro name -- confirm in op_boilerplate.h).
DECLARE_SYN ( bias_add , biasadd ) ;
////////////////////////////////////////////////////////////////////
/**
 * Shape function for biasadd.
 *
 * The single output takes its shape from the first input (the input array) and its
 * data type from the second input (the bias array).
 * NOTE(review): the output type follows the bias, not the input -- confirm this is intended.
 */
DECLARE_SHAPE_FN(biasadd) {
    auto inputShapeInfo = inputShape->at(0);
    auto biasShapeInfo  = inputShape->at(1);

    auto outShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(
        ShapeDescriptor(inputShapeInfo, ArrayOptions::dataType(biasShapeInfo)));

    return SHAPELIST(outShapeInfo);
}
/**
 * Type constraints for biasadd: inputs may be of any data type,
 * outputs are restricted to the ALL_FLOATS set.
 */
DECLARE_TYPES(biasadd) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
////////////////////////////////////////////////////////////////////
/**
 * biasadd_bp: backward pass of biasadd.
 *
 * Inputs : 0 - input array of the forward op
 *          1 - bias array of the forward op; must have rank 1 and length equal to the channel size
 *          2 - gradO, gradient with respect to the forward output; must have the input's shape
 * Outputs: 0 - gradI, gradient with respect to the input (a copy of gradO)
 *          1 - gradB, gradient with respect to the bias (gradO summed over every
 *              dimension except the channel dimension)
 * B args : 0 - optional flag; when true the channel dimension is 1 (NCHW layout),
 *              when false or absent the last dimension is used
 */
CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
    auto input = INPUT_VARIABLE(0);
    auto bias  = INPUT_VARIABLE(1);
    auto gradO = INPUT_VARIABLE(2);

    auto gradI = OUTPUT_VARIABLE(0);
    auto gradB = OUTPUT_VARIABLE(1);

    const bool isNCHW    = !block.getBArguments()->empty() ? B_ARG(0) : false;
    const int channelDim = isNCHW ? 1 : input->rankOf() - 1;   // second or last

    // Same input contract as the forward op: without these checks a malformed bias or a
    // gradO whose rank differs from the input's would be reduced along a wrong dimension.
    REQUIRE_TRUE(bias->rankOf() == 1, 0,
                 "BIASADD_BP CUSTOM_OP: bias array should have rank = 1, but got %i instead !",
                 bias->rankOf());
    REQUIRE_TRUE(bias->sizeAt(0) == input->sizeAt(channelDim), 0,
                 "BIASADD_BP CUSTOM_OP: shapes of bias %s and input %s arrays are not suitable for broadcast operation along channel dimension %i !",
                 ShapeUtils::shapeAsString(bias).c_str(),
                 ShapeUtils::shapeAsString(input).c_str(),
                 channelDim);
    REQUIRE_TRUE(gradO->isSameShape(input), 0,
                 "BIASADD_BP CUSTOM_OP: wrong shape of output gradients array, expected is %s but got %s instead !",
                 ShapeUtils::shapeAsString(input).c_str(),
                 ShapeUtils::shapeAsString(gradO).c_str());

    // The input gradient is the output gradient passed through unchanged.
    gradI->assign(gradO);

    // The bias gradient is the sum of gradO over all non-channel dimensions.
    gradO->reduceAlongDimension(sd::reduce::Sum, *gradB,
                                ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));

    return Status::OK();
}
// Declares `BiasAddGrad` as a synonym of the `biasadd_bp` op (alias semantics presumed from the DECLARE_SYN macro name -- confirm in op_boilerplate.h).
DECLARE_SYN ( BiasAddGrad , biasadd_bp ) ;
2019-11-03 11:37:19 +01:00
////////////////////////////////////////////////////////////////////
2019-09-11 19:12:09 +02:00
/**
 * Shape function for biasadd_bp.
 *
 * Produces two output shapes: the first is a copy of the first input's shape info
 * (the input array), the second is a copy of the second input's shape info (the bias).
 */
DECLARE_SHAPE_FN(biasadd_bp) {
    auto inputShapeInfo = inputShape->at(0);
    auto biasShapeInfo  = inputShape->at(1);

    // Shape info for the input gradient mirrors the input.
    Nd4jLong* gradIShapeInfo = nullptr;
    COPY_SHAPE(inputShapeInfo, gradIShapeInfo);

    // Shape info for the bias gradient mirrors the bias.
    Nd4jLong* gradBShapeInfo = nullptr;
    COPY_SHAPE(biasShapeInfo, gradBShapeInfo);

    return SHAPELIST(CONSTANT(gradIShapeInfo), CONSTANT(gradBShapeInfo));
}
/**
 * Type constraints for biasadd_bp: inputs may be of any data type,
 * outputs are restricted to the ALL_FLOATS set.
 */
DECLARE_TYPES(biasadd_bp) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
}
2019-06-06 14:21:15 +02:00
}
# endif