2021-02-01 13:31:45 +01:00
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_multiply)
# include <ops/declarable/CustomOperations.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
/**
 * Forward pass of the broadcastable element-wise multiply op.
 *
 * Inputs:  0 - left operand, 1 - right operand
 * Output:  0 - product, written in place into the pre-allocated output array
 *
 * Returns ND4J_STATUS_KERNEL_FAILURE when the broadcast helper yields no result,
 * throws std::runtime_error when the helper hands back an array other than the
 * provided output, and Status::OK() otherwise.
 */
BROADCASTABLE_OP_IMPL(multiply, 0, 0) {
    auto lhs = INPUT_VARIABLE(0);
    auto rhs = INPUT_VARIABLE(1);
    auto out = OUTPUT_VARIABLE(0);

    BROADCAST_CHECK_EMPTY(lhs, rhs, out);

    // The broadcast shape itself is not used below; it is evaluated only to
    // validate that the two operand shapes are compatible.
    const Nd4jLong* broadcastShapeInfo = nullptr;
    const bool canBroadcast = ShapeUtils::evalBroadcastShapeInfo(lhs->shapeInfo(), rhs->shapeInfo(), true, broadcastShapeInfo, block.getWorkspace());
    REQUIRE_TRUE(canBroadcast, 0, " MULTIPLY OP: the shapes of x %s and y %s are not suitable for broadcast ! ", ShapeUtils::shapeAsString(lhs).c_str(), ShapeUtils::shapeAsString(rhs).c_str());

    auto produced = BroadcastHelper::broadcastApply(sd::BroadcastOpsTuple::Multiply(), lhs, rhs, out);

    // No result at all: report a kernel failure through the status code.
    if (produced == nullptr)
        return ND4J_STATUS_KERNEL_FAILURE;

    // The helper must have written into the output we supplied, not a substitute.
    if (produced != out)
        throw std::runtime_error(" multiply: result was replaced ");

    return Status::OK();
}
// NOTE(review): presumably registers `Mul` as an alternative name (synonym) for the
// `multiply` op - confirm against the DECLARE_SYN macro definition.
DECLARE_SYN ( Mul , multiply ) ;
/**
 * Type constraints for `multiply`: inputs 0 and 1 are declared with DataType::ANY,
 * output 0 is declared with DataType::INHERIT.
 */
DECLARE_TYPES(multiply) {
    // Each call is made on the value returned by the previous one, exactly as in a
    // fluent chain.
    auto afterFirstInput  = getOpDescriptor()->setAllowedInputTypes(0, DataType::ANY);
    auto afterSecondInput = afterFirstInput->setAllowedInputTypes(1, DataType::ANY);
    afterSecondInput->setAllowedOutputTypes(0, DataType::INHERIT);
}
/**
 * Type constraints for `multiply_bp`: inputs are declared with DataType::ANY,
 * outputs are restricted to the ALL_FLOATS set.
 */
DECLARE_TYPES(multiply_bp) {
    // The output restriction is applied to whatever the input setter returns,
    // exactly as in a fluent chain.
    auto afterInputs = getOpDescriptor()->setAllowedInputTypes(DataType::ANY);
    afterInputs->setAllowedOutputTypes({ALL_FLOATS});
}
///////////////////////////////////////////////////////////////////
/**
 * Backward pass of the broadcastable element-wise multiply op.
 *
 * Inputs:  0 - x, 1 - y, 2 - dLdz (the "next epsilon" / dLdOut array, which must
 *              have the broadcast shape of x and y)
 * Outputs: 0 - dLdx, 1 - dLdy
 *
 * For z = x * y the gradients are dLdx = y * dLdz and dLdy = x * dLdz. Whenever an
 * input is smaller than dLdz, its gradient is summed down to that input's shape.
 *
 * Fails via REQUIRE_TRUE when x and y are not broadcastable or when dLdz does not
 * have the broadcast shape; returns Status::OK() otherwise.
 */
CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
    auto x    = INPUT_VARIABLE(0);
    auto y    = INPUT_VARIABLE(1);
    auto dLdz = INPUT_VARIABLE(2);

    auto dLdx = OUTPUT_VARIABLE(0);
    auto dLdy = OUTPUT_VARIABLE(1);

    // Expected shape of dLdz is the broadcast shape of the two forward inputs.
    const Nd4jLong* dLdzShapeInfo = nullptr;
    const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->shapeInfo(), y->shapeInfo(), true, dLdzShapeInfo, block.getWorkspace());

    REQUIRE_TRUE(areShapesBroadcastable, 0, " MULTIPLY_BP OP: the shapes of x %s and y %s are not suitable for broadcast ! ", ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
    REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0, " MULTIPLY_BP OP: wrong shape of next epsilon array (dLdOut), expected is %s, but got %s instead ! ", ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str());

    if (x->isScalar() && y->isScalar()) {
        // Both inputs are scalars: plain pairwise products, nothing to reduce.
        y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
        x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
    }
    else if (x->isScalar()) {
        // x is a scalar, y is not: dLdx collapses to a single summed number,
        // dLdy is dLdz scaled by x.
        dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum));
        dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy);
    }
    else if (y->isScalar()) {
        // y is a scalar, x is not: mirror image of the branch above.
        dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum));
        dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx);
    }
    else if (x->isSameShape(y)) {
        // Same shapes: no broadcasting happened, so no reduction is needed.
        x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
        y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
    }
    else if (x->isSameShape(dLdz)) {
        // Only y is smaller than dLdz: tile y up to dLdz's shape for dLdx and
        // sum x * dLdz down to y's shape for dLdy.
        auto yTiled = NDArray(dLdz, false, block.launchContext());
        y->tile(yTiled);
        // NOTE(review): presumably the axes of dLdz along which y was broadcast - confirm.
        std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo());

        dLdy->assign((*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY));
        yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
    }
    else if (y->isSameShape(dLdz)) {
        // Only x is smaller than dLdz: mirror image of the branch above.
        auto xTiled = NDArray(dLdz, false, block.launchContext());
        x->tile(xTiled);
        std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo());

        dLdx->assign((*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX));
        xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
    }
    else {
        // Both inputs are smaller than dLdz. The products below rely on the
        // broadcasting operator*, so no tiled copies of x or y are required;
        // each product is then summed down to the corresponding input's shape.
        std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), dLdz->shapeInfo());
        std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), dLdz->shapeInfo());

        dLdx->assign((*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX));
        dLdy->assign((*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY));
    }

    return Status::OK();
}
/**
 * Shape function for `multiply_bp`: output 0 (dLdx) gets a copy of input 0's shape
 * info and output 1 (dLdy) gets a copy of input 1's shape info.
 */
DECLARE_SHAPE_FN(multiply_bp) {
    // Gradient w.r.t. x mirrors the shape of x.
    auto shapeOfX = inputShape->at(0);
    Nd4jLong* gradXShapeInfo = nullptr;

    // Gradient w.r.t. y mirrors the shape of y.
    auto shapeOfY = inputShape->at(1);
    Nd4jLong* gradYShapeInfo = nullptr;

    COPY_SHAPE(shapeOfX, gradXShapeInfo);
    COPY_SHAPE(shapeOfY, gradYShapeInfo);

    return SHAPELIST(CONSTANT(gradXShapeInfo), CONSTANT(gradYShapeInfo));
}
/*
CUSTOM_OP_IMPL ( multiply_bp , 3 , 2 , false , 0 , 0 ) {
auto x = INPUT_VARIABLE ( 0 ) ;
auto y = INPUT_VARIABLE ( 1 ) ;
auto epsNext = INPUT_VARIABLE ( 2 ) ;
auto gradX = OUTPUT_VARIABLE ( 0 ) ;
auto gradY = OUTPUT_VARIABLE ( 1 ) ;
auto lambdaX = LAMBDA_TT ( _e , _y ) {
return _e * _y ;
} ;
auto lambdaY = LAMBDA_TT ( _e , _x ) {
return _e * _x ;
} ;
if ( x - > isSameShape ( y ) ) {
// PWT case
// X gradient
epsNext - > applyPairwiseLambda ( y , lambdaX , gradX ) ;
// Y gradient
epsNext - > applyPairwiseLambda ( x , lambdaY , gradY ) ;
} else if ( y - > isScalar ( ) ) {
// scalar case
T _y = y - > e ( 0 ) ;
auto lambdaS = LAMBDA_T ( _e , _y ) {
return _e * _y ;
} ;
T tmpX = x - > template reduceNumber < simdOps : : Sum < T > > ( ) ;
gradY - > assign ( tmpX ) ;
2019-12-20 20:35:39 +01:00
epsNext - > applyLambda ( lambdaS , * gradX ) ;
2019-06-06 14:21:15 +02:00
} else {
// broadcast case
auto preX = x - > dup ( ) ;
auto preY = y - > dup ( ) ;
auto targetShape = epsNext - > getShapeAsVector ( ) ;
preX - > tileToShape ( targetShape ) ;
preY - > tileToShape ( targetShape ) ;
auto axisX = ShapeUtils : : evalBroadcastBackwardAxis ( x - > shapeInfo ( ) , epsNext - > shapeInfo ( ) ) ;
auto axisY = ShapeUtils : : evalBroadcastBackwardAxis ( y - > shapeInfo ( ) , epsNext - > shapeInfo ( ) ) ;
if ( axisX . size ( ) > 0 ) {
auto sum = preX - > template reduceAlongDimension < simdOps : : Sum < T > > ( axisX ) ;
gradX - > assign ( sum ) ;
delete sum ;
2019-12-20 20:35:39 +01:00
} else
2019-06-06 14:21:15 +02:00
gradX - > assign ( preX ) ;
if ( axisY . size ( ) > 0 ) {
auto sum = preY - > template reduceAlongDimension < simdOps : : Sum < T > > ( axisY ) ;
gradY - > assign ( sum ) ;
delete sum ;
} else
gradY - > assign ( preY ) ;
delete preX ;
delete preY ;
}
return Status : : OK ( ) ;
}
*/
}
}
# endif