Oleh powderev (#171)
* Libnd4j: Add broadcastable elementwise power derivative #7461 first step of Pow_bp operation implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 some corrections of calculation steps Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 some bug fixes, the PowDerevative op made broadcastable, add the raw tests for op, need refactoring to use broadcast ops * Libnd4j: Add broadcastable elementwise power derivative #7461 fixed several bugs add broadcast support and tests, need to fix scalar+array and array+scalar Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 fixed bugs for scalar inputs, fixed multinomial tests, added tests Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 fised bugs for different shapes support, tests updated * Libnd4j: Add broadcastable elementwise power derivative #7461 applied all possible variants via tiled arrays, add support of broadcast for Pow and PowDerivative ops, covered by tests, before review have to be replaced tiled implementation by applyTrueBroadcast Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 replaced tile by broadcast implementation, fixed issue with negative x input, corrected tests, need additional testing Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 added and corrected test cases, corrected implementation need review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 code clean up * Libnd4j: Add broadcastable elementwise power derivative #7461 code clean up, removed some tests, add tests with scalar Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 code improvement and clean up, split tests Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative #7461 some code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: Add broadcastable elementwise power derivative replace __isnanf by internal realization Signed-off-by: Oleg <oleg.semeniv@gmail.com> * pow_bp wrapper * Fixed PowBp wrapper * Tests added * Test fixed * Fix return type * Disable powBp usage * Pow backprop changed Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
6943a5f57a
commit
8fc0e63ce7
|
@ -53,6 +53,7 @@ inline pairwise::Ops fromBroadcastToPairwise(broadcast::Ops op) {
|
|||
case broadcast::LogicalXor: return pairwise::LogicalXor;
|
||||
case broadcast::LogicalNot: return pairwise::LogicalNot;
|
||||
case broadcast::LogicalAnd: return pairwise::LogicalAnd;
|
||||
case broadcast::PowDerivative: return pairwise::PowDerivative;
|
||||
default:
|
||||
throw std::runtime_error("fromBroadcastToPairwise: Not convertible operation");
|
||||
}
|
||||
|
|
|
@ -80,7 +80,8 @@
|
|||
(30, LogicalAnd), \
|
||||
(31, DivideNoNan), \
|
||||
(32, IGamma), \
|
||||
(33, IGammac)
|
||||
(33, IGammac),\
|
||||
(34, PowDerivative)
|
||||
|
||||
// these ops return same data type as input
|
||||
#define TRANSFORM_SAME_OPS \
|
||||
|
|
|
@ -52,6 +52,9 @@ namespace nd4j {
|
|||
static BroadcastOpsTuple Subtract();
|
||||
static BroadcastOpsTuple IGamma();
|
||||
static BroadcastOpsTuple IGammac();
|
||||
|
||||
static BroadcastOpsTuple Pow();
|
||||
static BroadcastOpsTuple PowDerivative();
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -51,6 +52,75 @@ namespace nd4j {
|
|||
->setAllowedInputTypes(1, {ALL_FLOATS, ALL_INTS})
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS});
|
||||
}
|
||||
|
||||
CUSTOM_OP_IMPL(Pow_bp, 3, 2, false, 0, 0) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto dLdz = INPUT_VARIABLE(2);
|
||||
|
||||
auto dLdx = OUTPUT_VARIABLE(0);
|
||||
auto dLdy = OUTPUT_VARIABLE(1);
|
||||
|
||||
Nd4jLong* dLdzShapeInfo = nullptr;
|
||||
const bool areShapesBroadcastable = ShapeUtils::evalBroadcastShapeInfo(x->getShapeInfo(), y->getShapeInfo(), true, dLdzShapeInfo, block.getWorkspace());
|
||||
REQUIRE_TRUE(areShapesBroadcastable, 0, "POW_BP OP: the shapes of x %s"
|
||||
" and y %s are not suitable for broadcast !",
|
||||
ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
|
||||
REQUIRE_TRUE(shape::equalsSoft(dLdz->shapeInfo(), dLdzShapeInfo), 0,
|
||||
"POW_BP OP: wrong shape of next epsilon array (dLdOut),"
|
||||
" expected is %s, but got %s instead !",
|
||||
ShapeUtils::shapeAsString(dLdzShapeInfo).c_str(), ShapeUtils::shapeAsString(dLdz).c_str());
|
||||
|
||||
// dL/dy = x^y * log(x) * dL/dz
|
||||
auto temp = x->applyTrueBroadcast(BroadcastOpsTuple::Pow(), *y); // a = x^y
|
||||
x->applyTransform(transform::Log, *dLdx); // b = log(x)
|
||||
dLdx->applyScalar(nd4j::scalar::ReplaceNans, 0, *dLdx);
|
||||
temp *= *dLdx; // c = b*a
|
||||
temp *= *dLdz; // dL/dy = c * dL/dz
|
||||
if (dLdy->isSameShape(*dLdz)) {
|
||||
dLdy->assign(temp);
|
||||
}
|
||||
else {
|
||||
std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo());
|
||||
dLdy->assign(temp.reduceAlongDimension(reduce::Sum, axesForY)); // dL/dy = sum(c * dL/dz)
|
||||
}
|
||||
|
||||
// dL/dx = y*x^(y-1) * dL/dz
|
||||
x->applyTrueBroadcast(BroadcastOpsTuple::PowDerivative(), *y, temp); // a = y*x^(y-1)
|
||||
temp *= *dLdz; // dLdx = a*dL/dz
|
||||
|
||||
if (dLdx->isSameShape(*dLdz)) {
|
||||
dLdx->assign(temp); // dLdx = a*dL/dz
|
||||
}
|
||||
else {
|
||||
std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo());
|
||||
dLdx->assign(temp.reduceAlongDimension(reduce::Sum, axesForX)); // dLdx = a*dL/dz
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(Pow_bp) {
|
||||
|
||||
auto xShapeInfo = inputShape->at(0);
|
||||
auto yShapeInfo = inputShape->at(1);
|
||||
|
||||
Nd4jLong* dLdxShapeInfo = nullptr;
|
||||
Nd4jLong* dLdyShapeInfo = nullptr;
|
||||
|
||||
COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
|
||||
COPY_SHAPE(yShapeInfo, dLdyShapeInfo);
|
||||
|
||||
return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdyShapeInfo));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(Pow_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
|
||||
->setAllowedOutputTypes({ ALL_FLOATS }); // TODO maybe wourth to add ALL_INTS
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -356,6 +356,7 @@ namespace nd4j {
|
|||
*/
|
||||
#if NOT_EXCLUDED(OP_Pow)
|
||||
DECLARE_BROADCASTABLE_OP(Pow, 0, 0);
|
||||
DECLARE_CUSTOM_OP(Pow_bp, 3, 2, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -55,4 +55,12 @@ namespace nd4j {
|
|||
return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac);
|
||||
}
|
||||
|
||||
|
||||
BroadcastOpsTuple BroadcastOpsTuple::Pow() {
|
||||
return custom(nd4j::scalar::Pow, nd4j::pairwise::Pow, nd4j::broadcast::Pow);
|
||||
}
|
||||
BroadcastOpsTuple BroadcastOpsTuple::PowDerivative() {
|
||||
return custom(nd4j::scalar::PowDerivative, nd4j::pairwise::PowDerivative, nd4j::broadcast::PowDerivative);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1280,3 +1280,334 @@ TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test1) {
|
||||
|
||||
// same shape
|
||||
NDArray x('c', { 2,2,2 }, { 4,3,2,5,7,8,-9,-12 }, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', { 2,2,2 }, { 2,3,-2,4,-1,-4,10,8 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
NDArray dLdz('c', { 2,2,2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdxExp('c', { 2,2,2 }, { 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExp('c', { 2,2,2 }, { 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
dLdz.assign(1.0);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto results = op.execute({ &x, &y, &dLdz }, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto* dLdx = results->at(0);
|
||||
auto* dLdy = results->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
|
||||
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
||||
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
||||
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test2) {
|
||||
|
||||
NDArray x('c', { 1,2,3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', { 3,2,1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdxExp('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24., 26.4, 28.8 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExp('c', { 3,2,1 }, { 13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
x.assign(4.0);
|
||||
y.assign(2.0);
|
||||
dLdz.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto results = op.execute({ &x, &y, &dLdz }, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto* dLdx = results->at(0);
|
||||
auto* dLdy = results->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
|
||||
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
||||
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
||||
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test3) {
|
||||
|
||||
// y - same shape as dLdz
|
||||
NDArray xY('c', { 1,2,3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray yY('c', { 3,2,3 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdxExpY('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24. , 26.4, 28.8 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExpY('c', { 3,2,3 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, 15.5265 , 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
xY.assign(4.0);
|
||||
yY.assign(2.0);
|
||||
dLdz.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto resultsY = op.execute({ &xY, &yY, &dLdz }, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsY->status());
|
||||
|
||||
auto* dLdxY = resultsY->at(0);
|
||||
auto* dLdyY = resultsY->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY));
|
||||
ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY));
|
||||
ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY));
|
||||
ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY));
|
||||
|
||||
delete resultsY;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test4) {
|
||||
|
||||
// x - same shape ad dLdz
|
||||
NDArray yX('c', { 1,2,3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray xX('c', { 3,2,3 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdxExpX('c', { 3,2,3 }, { 3.2, 6.4, 9.6, 12.8, 16. , 19.2, 22.4, 25.6, 28.8, 32. , 35.2, 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExpX('c', { 1,2,3 }, { 23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32);
|
||||
dLdz.linspace(0.1, 0.1);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
|
||||
xX.assign(2.0);
|
||||
yX.assign(4.0);
|
||||
|
||||
auto resultsX = op.execute({ &xX, &yX, &dLdz }, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsX->status());
|
||||
|
||||
auto* dLdxX = resultsX->at(0);
|
||||
auto* dLdyX = resultsX->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX));
|
||||
ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX));
|
||||
ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX));
|
||||
ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX));
|
||||
|
||||
delete resultsX;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test5) {
|
||||
|
||||
// both single array
|
||||
NDArray xConst('c', { 1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray yConst('c', { 1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdz('c', { 1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdxExp('c', { 1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExp('c', { 1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
xConst.assign(3.0);
|
||||
yConst.assign(4.0);
|
||||
dLdz.assign(1.0);
|
||||
|
||||
dLdxExp.assign(4.0 * pow(3, 3));
|
||||
dLdyExp.assign(pow(3, 4) * log(3));
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto results = op.execute({ &xConst, &yConst, &dLdz }, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto* dLdx = results->at(0);
|
||||
auto* dLdy = results->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
|
||||
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
||||
|
||||
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
||||
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test6) {
|
||||
|
||||
// x single array
|
||||
NDArray xConst('c', { 1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
xConst.assign(2.0);
|
||||
y.assign(4.0);
|
||||
dLdzC.linspace(0.1, 0.1);
|
||||
|
||||
NDArray dLdxExpXC('c', { 1 }, { 115.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto resultsXC = op.execute({ &xConst, &y, &dLdzC }, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsXC->status());
|
||||
|
||||
auto* dLdxXC = resultsXC->at(0);
|
||||
auto* dLdyXC = resultsXC->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC));
|
||||
ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC));
|
||||
ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC));
|
||||
ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC));
|
||||
|
||||
delete resultsXC;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test7) {
|
||||
|
||||
// Y - scalar
|
||||
auto Y = NDArrayFactory::create<float>(2.f);
|
||||
NDArray x('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdzC('c', { 2, 2, 2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
dLdzC.linspace(0.1, 0.1);
|
||||
x = 4.f;
|
||||
|
||||
NDArray dLdxExpYs('c', { 2, 2, 2 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto dLdyExpYs = NDArrayFactory::create<float>(79.85056f);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto resultsYs = op.execute({ &x, &Y, &dLdzC }, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsYs->status());
|
||||
|
||||
auto* dLdxY = resultsYs->at(0);
|
||||
auto* dLdyY = resultsYs->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY));
|
||||
ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY));
|
||||
ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY));
|
||||
ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY));
|
||||
|
||||
delete resultsYs;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test8) {
|
||||
// both scalars
|
||||
|
||||
auto X = NDArrayFactory::create<float>(4.f);
|
||||
auto Y = NDArrayFactory::create<float>(2.f);
|
||||
NDArray dLdz = NDArrayFactory::create<float>(0.1f);
|
||||
|
||||
NDArray dLdxExp = NDArrayFactory::create<float>(2.f*4.f*0.1f);
|
||||
|
||||
NDArray dLdyExp = NDArrayFactory::create<float>(pow(4.f, 2.f) * log(4.f) * 0.1f);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto results = op.execute({ &X, &Y, &dLdz }, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto* dLdx = results->at(0);
|
||||
auto* dLdy = results->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
|
||||
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
||||
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
||||
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test9) {
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
// diff shapes
|
||||
NDArray x('c', { 3,2,1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', { 1,2,3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdz('c', { 3,2,3 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdxExp('c', { 3,2,1 }, { 4.8, 12., 19.2, 26.4, 33.6, 40.8 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExp('c', { 1,2,3 }, { 46.57949, 53.2337 , 59.88792, 66.54213, 73.19634, 79.85056 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
x.assign(4.0);
|
||||
y.assign(2.0);
|
||||
dLdz.linspace(0.1, 0.1);
|
||||
|
||||
auto results = op.execute({ &x, &y, &dLdz }, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto* dLdx = results->at(0);
|
||||
auto* dLdy = results->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExp.isSameShape(dLdx));
|
||||
ASSERT_TRUE(dLdxExp.equalsTo(dLdx));
|
||||
ASSERT_TRUE(dLdyExp.isSameShape(dLdy));
|
||||
ASSERT_TRUE(dLdyExp.equalsTo(dLdy));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test10) {
|
||||
|
||||
// diff shapes broadcastable
|
||||
NDArray yB('c', { 1,2,3,1 }, nd4j::DataType::FLOAT32);
|
||||
NDArray xB('c', { 2,3,1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdyExpB('c', { 1,2,3,1 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdxExpB('c', { 2,3,1 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdzB('c', { 1,2,3,1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
dLdzB.linspace(0.1, 0.1);
|
||||
xB.assign(4.0);
|
||||
yB.assign(2.0);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsB->status());
|
||||
|
||||
auto* dLdxB = resultsB->at(0);
|
||||
auto* dLdyB = resultsB->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB));
|
||||
ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB));
|
||||
|
||||
ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB));
|
||||
ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB));
|
||||
|
||||
delete resultsB;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
|
||||
|
||||
NDArray xB('c', { 3,2,1 }, { .4, 3, 5, .8, -9, -12 }, nd4j::DataType::FLOAT32);
|
||||
NDArray yB('c', { 1,2,3 }, { 3, -2, .4, -4, 10, .8 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdxExpB('c', { 3,2,1 }, { -5.994056, 39366.191406, 7.508829, -2.223537, -std::numeric_limits<float>::quiet_NaN(), -std::numeric_limits<float>::quiet_NaN() }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExpB('c', { 1,2,3 }, { 20.11211, -1.119612, -std::numeric_limits<float>::quiet_NaN(), -0.1076, 12974.389648, -std::numeric_limits<float>::quiet_NaN() }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
auto resultsB = op.execute({ &xB, &yB, &dLdzB }, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsB->status());
|
||||
auto* dLdxB = resultsB->at(0);
|
||||
auto* dLdyB = resultsB->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB));
|
||||
for (int i = 0; i < dLdxB->lengthOf(); ++i) {
|
||||
if (!nd4j::math::nd4j_isnan(dLdxB->e<float>(i)) && !nd4j::math::nd4j_isnan(dLdxExpB.e<float>(i)))
|
||||
ASSERT_NEAR(dLdxB->e<float>(i), dLdxExpB.e<float>(i), 0.00001);
|
||||
}
|
||||
|
||||
ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB));
|
||||
for (int i = 0; i < dLdyB->lengthOf(); ++i) {
|
||||
if (!nd4j::math::nd4j_isnan(dLdyB->e<float>(i)) && !nd4j::math::nd4j_isnan(dLdyExpB.e<float>(i)))
|
||||
ASSERT_NEAR(dLdyB->e<float>(i), dLdyExpB.e<float>(i), 0.00001);
|
||||
}
|
||||
|
||||
delete resultsB;
|
||||
}
|
||||
|
|
|
@ -1090,7 +1090,7 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
// multinomial as binomial if 2 classes used
|
||||
int batchValue = 1;
|
||||
int ClassValue = 2;
|
||||
int Samples = 1000000;
|
||||
int Samples = 100000;
|
||||
|
||||
NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32);
|
||||
|
||||
|
@ -1107,8 +1107,8 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
auto mean = output.meanNumber();
|
||||
// printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
|
||||
// theoretical values for binomial
|
||||
ASSERT_NEAR(0.5, deviation.e<double>(0), 3e-3);
|
||||
ASSERT_NEAR(0.5, mean.e<double>(0), 3e-3);
|
||||
ASSERT_NEAR(0.5, deviation.e<double>(0), 4e-3); // 1000000 3e-3);
|
||||
ASSERT_NEAR(0.5, mean.e<double>(0), 4e-3); // 1000000 3e-3);
|
||||
|
||||
for (int i = 0; i < output.lengthOf(); i++) {
|
||||
auto value = output.e<Nd4jLong>(i);
|
||||
|
@ -1122,8 +1122,8 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
||||
mean = outputR->meanNumber();
|
||||
// printf("Random seed - Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
|
||||
ASSERT_NEAR(0.5, deviation.e<double>(0), 35e-3);
|
||||
ASSERT_NEAR(0.5, mean.e<double>(0), 35e-3);
|
||||
ASSERT_NEAR(0.5, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||
ASSERT_NEAR(0.5, mean.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||
|
||||
for (int i = 0; i < outputR->lengthOf(); i++) {
|
||||
auto value = outputR->e<Nd4jLong>(i);
|
||||
|
@ -1138,7 +1138,7 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
|
||||
int batchValue = 1;
|
||||
int ClassValue = 5;
|
||||
int Samples = 1000000;
|
||||
int Samples = 100000;
|
||||
|
||||
NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32);
|
||||
|
||||
|
@ -1165,14 +1165,14 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
auto c = countsR.e<double>(i);
|
||||
auto p = probExpect.e<double>(i);
|
||||
// printf("Get freq : %f Expect freq: %f \n", c / Samples, p);
|
||||
ASSERT_NEAR((c / Samples), p, 35e-3);
|
||||
ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3);
|
||||
}
|
||||
|
||||
auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
||||
auto mean = outputR->meanNumber();
|
||||
// printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
|
||||
ASSERT_NEAR(1.2175, deviation.e<double>(0), 35e-3);
|
||||
ASSERT_NEAR(2.906, mean.e<double>(0), 35e-3);
|
||||
ASSERT_NEAR(1.2175, deviation.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||
ASSERT_NEAR(2.906, mean.e<double>(0), 45e-3); // 1000000 35e-3);
|
||||
|
||||
delete resultR;
|
||||
|
||||
|
@ -1195,12 +1195,12 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
auto c = counts.e<double>(i);
|
||||
auto p = probExpect.e<double>(i);
|
||||
// printf("Get freq : %f Expect freq: %f \n", c / Samples, p);
|
||||
ASSERT_NEAR((c / Samples), p, 3e-3);
|
||||
ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3);
|
||||
}
|
||||
|
||||
deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
||||
mean = output.meanNumber();
|
||||
// printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
|
||||
ASSERT_NEAR(1.2175, deviation.e<double>(0), 3e-3);
|
||||
ASSERT_NEAR(2.906, mean.e<double>(0), 3e-3);
|
||||
ASSERT_NEAR(1.2175, deviation.e<double>(0), 5e-3); // 1000000 3e-3);
|
||||
ASSERT_NEAR(2.906, mean.e<double>(0), 5e-3); // 1000000 3e-3);
|
||||
}
|
||||
|
|
|
@ -92,20 +92,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
|
|||
import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bp.*;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
|
||||
|
@ -1420,6 +1407,10 @@ public class DifferentialFunctionFactory {
|
|||
return new PowDerivative(sameDiff(), iX, false, pow).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable[] powBp(SDVariable x, SDVariable pow, SDVariable gradient) {
|
||||
return new PowBp(sameDiff(), x, pow, gradient).outputVariables();
|
||||
}
|
||||
|
||||
public SDVariable mishDerivative(SDVariable iX) {
|
||||
return new MishDerivative(sameDiff(), iX, false).outputVariable();
|
||||
}
|
||||
|
|
|
@ -230,6 +230,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.scalar.LogX.class,
|
||||
org.nd4j.linalg.api.ops.impl.scalar.Pow.class,
|
||||
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.bp.PowBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
|
||||
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
package org.nd4j.linalg.api.ops.impl.reduce.bp;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.BaseArithmeticBackpropOp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class PowBp extends BaseDynamicTransformOp {
|
||||
|
||||
public PowBp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable dLdz) {
|
||||
super(sameDiff,new SDVariable[]{x,y,dLdz}, false);
|
||||
}
|
||||
|
||||
public PowBp(INDArray x, INDArray y, INDArray dLdz,
|
||||
INDArray dLdx, INDArray dLdy) {
|
||||
super(new INDArray[]{x,y, dLdz}, new INDArray[]{dLdx, dLdy});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "Pow_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isInplaceCall() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes);
|
||||
//Gradient types: same as input
|
||||
return Arrays.asList(arg(0).dataType(), arg(1).dataType());
|
||||
}
|
||||
}
|
|
@ -19,7 +19,9 @@ package org.nd4j.linalg.api.ops.impl.scalar;
|
|||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
||||
|
@ -29,6 +31,7 @@ import org.tensorflow.framework.GraphDef;
|
|||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -93,5 +96,4 @@ public class Pow extends BaseScalarOp {
|
|||
SDVariable g = f().powDerivative(arg(), this.pow).mul(i_v1.get(0));
|
||||
return Arrays.asList(g);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -62,11 +62,14 @@ public class Pow extends DynamicCustomOp {
|
|||
//dL/da = b*a^(b-1) * dL/dy
|
||||
//dL/db = a^b * log(a) * dL/dy
|
||||
|
||||
SDVariable a = arg(0);
|
||||
/*SDVariable a = arg(0);
|
||||
SDVariable b = arg(1);
|
||||
SDVariable dlda = b.mul(sameDiff.math().pow(a,b.sub(1))).mul(f1.get(0));
|
||||
SDVariable dldb = outputVariable().mul(sameDiff.math().log(a)).mul(f1.get(0));
|
||||
return Arrays.asList(dlda, dldb);
|
||||
return Arrays.asList(dlda, dldb);*/
|
||||
|
||||
SDVariable[] g = f().powBp(arg(0), arg(1), f1.get(0));
|
||||
return Arrays.asList(g);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -921,4 +921,60 @@ public class ReductionBpOpValidation extends BaseOpValidation {
|
|||
assertNull(err, err);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPowBP() {
|
||||
|
||||
for (boolean keepDims : new boolean[]{false, true}) {
|
||||
|
||||
INDArray preReduceInput_1 = Nd4j.createFromArray(new double[]{
|
||||
4,3,2,5,7,8,-9,-12
|
||||
}).reshape(2,2,2);
|
||||
INDArray preReduceInput_2 = Nd4j.createFromArray(new double[]{
|
||||
2,3,-2,4,-1,-4,10,8
|
||||
}).reshape(2,2,2);
|
||||
INDArray preReduceInput_3 = Nd4j.linspace(1, 8, 8).reshape(2, 2,2);
|
||||
INDArray gradOutput = Nd4j.valueArrayOf(new long[]{2, 2, 2}, 1.0);
|
||||
INDArray dLdInExpected_1 = Nd4j.createFromArray(new double[]{
|
||||
8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08
|
||||
}).reshape(2,2,2);
|
||||
INDArray dLdInExpected_2 = Nd4j.createFromArray(new double[]{
|
||||
22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0
|
||||
}).reshape(2,2,2);
|
||||
INDArray output1 = Nd4j.createUninitialized(2, 2,2);
|
||||
INDArray output2 = Nd4j.createUninitialized(2, 2,2);
|
||||
|
||||
String err = OpValidation.validate(new OpTestCase(new PowBp(preReduceInput_1, preReduceInput_2,
|
||||
gradOutput, output1, output2))
|
||||
.expectedOutput(0, dLdInExpected_1).expectedOutput(1, dLdInExpected_2));
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPowBP1() {
|
||||
|
||||
INDArray preReduceInput_1 = Nd4j.createFromArray(new float[]{
|
||||
0.0714f, 0.4735f, -0.1249f, 0.4482f,
|
||||
-0.1376f, 0.5218f, 0.5558f, 0.2444f,
|
||||
-0.5297f, 0.4291f, 0.4913f, -0.1178f
|
||||
}).reshape(3,4);
|
||||
INDArray preReduceInput_2 = Nd4j.scalar(2.0000f);
|
||||
|
||||
INDArray gradOutput = Nd4j.valueArrayOf(new long[]{3, 4}, 1.0f);
|
||||
|
||||
INDArray output1 = Nd4j.createUninitialized(DataType.FLOAT, 3,4);
|
||||
INDArray output2 = Nd4j.scalar(DataType.FLOAT, 1.0); //Nd4j.createUninitialized(DataType.FLOAT, 3,4);
|
||||
|
||||
INDArray expected1 = Nd4j.createFromArray(new float[]{
|
||||
0.1428f, 0.9470f, -0.2498f, 0.8964f,
|
||||
-0.2752f, 1.0436f, 1.1116f, 0.4888f,
|
||||
-1.0594f, 0.8582f, 0.9826f, -0.2356f
|
||||
}).reshape(3,4);
|
||||
INDArray expected2 = Nd4j.scalar(DataType.FLOAT, -1.112316132);
|
||||
String err = OpValidation.validate(new OpTestCase(new PowBp(preReduceInput_1, preReduceInput_2,
|
||||
gradOutput, output1, output2)).expectedOutput(0, expected1).expectedOutput(1, expected2));
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue