- add parameter alpha to elu and lrelu_bp (#213)

* - add parameter alpha to elu and lrelu_bp

Signed-off-by: Yurii <yurii@skymind.io>

* - forgot to correct header activations.h

Signed-off-by: Yurii <yurii@skymind.io>
master
Yurii Shyrma 2019-08-31 20:57:39 +03:00 committed by raver119
parent b71c993ded
commit a35926c6e9
11 changed files with 97 additions and 98 deletions

View File

@ -116,7 +116,6 @@
#define TRANSFORM_STRICT_OPS \ #define TRANSFORM_STRICT_OPS \
(3, ELUDerivative), \
(4, TanhDerivative), \ (4, TanhDerivative), \
(5, HardTanhDerivative), \ (5, HardTanhDerivative), \
(6, SigmoidDerivative), \ (6, SigmoidDerivative), \
@ -148,7 +147,6 @@
(32, ATan), \ (32, ATan), \
(33, HardTanh), \ (33, HardTanh), \
(34, SoftSign), \ (34, SoftSign), \
(35, ELU), \
(36, HardSigmoid), \ (36, HardSigmoid), \
(37, RationalTanh) ,\ (37, RationalTanh) ,\
(38, RectifiedTanh) ,\ (38, RectifiedTanh) ,\
@ -211,6 +209,8 @@
(4, ReverseDivide),\ (4, ReverseDivide),\
(5, ReverseSubtract),\ (5, ReverseSubtract),\
(6, MaxPairwise),\ (6, MaxPairwise),\
(7, ELU), \
(8, ELUDerivative), \
(13, MinPairwise),\ (13, MinPairwise),\
(14, CopyPws),\ (14, CopyPws),\
(15, Mod),\ (15, Mod),\

View File

@ -25,12 +25,14 @@
#include <ops/declarable/helpers/legacy_helpers.h> #include <ops/declarable/helpers/legacy_helpers.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(elu, 1, 1, true, 0, 0) { CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::ELU, output, nullptr); const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
STORE_RESULT(output);
input->applyScalar(nd4j::scalar::ELU, alpha, output);
return Status::OK(); return Status::OK();
} }
@ -41,14 +43,18 @@ namespace nd4j {
->setAllowedOutputTypes(0, {ALL_FLOATS}); ->setAllowedOutputTypes(0, {ALL_FLOATS});
} }
CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, 0, 0) { CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto epsilon = INPUT_VARIABLE(1); auto epsilon = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
// input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output);
helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha);
//input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, z, nullptr);
helpers::eluDerivative(block.launchContext(), input, epsilon, z);
return Status::OK(); return Status::OK();
} }

View File

@ -25,13 +25,13 @@
#include <ops/declarable/helpers/legacy_helpers.h> #include <ops/declarable/helpers/legacy_helpers.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, 1, 0) { CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
float t = block.numT() > 0 ? T_ARG(0) : 0.0f; float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
input->applyScalar(nd4j::scalar::LeakyRELU, t, output); input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();
@ -43,14 +43,16 @@ namespace nd4j {
->setAllowedOutputTypes(0, {ALL_FLOATS}); ->setAllowedOutputTypes(0, {ALL_FLOATS});
} }
CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, 0, 0) { CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto epsilon = INPUT_VARIABLE(1); auto epsilon = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
//input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr); //input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr);
helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z); helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha);
return Status::OK(); return Status::OK();
} }

View File

@ -82,8 +82,8 @@ namespace nd4j {
* Math is: x < 0 ? alpha * x : x; * Math is: x < 0 ? alpha * x : x;
*/ */
#if NOT_EXCLUDED(OP_lrelu) #if NOT_EXCLUDED(OP_lrelu)
DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, 0, 0); DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0);
DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, 0, 0); DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0);
#endif #endif
/** /**
@ -91,8 +91,8 @@ namespace nd4j {
* Math is: x >= 0 ? x : exp(x) - 1; * Math is: x >= 0 ? x : exp(x) - 1;
*/ */
#if NOT_EXCLUDED(OP_elu) #if NOT_EXCLUDED(OP_elu)
DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, 0, 0); DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0);
DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, 0, 0); DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0);
#endif #endif
/** /**

View File

@ -81,29 +81,35 @@ namespace helpers {
} }
template <typename T> template <typename T>
static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
auto functor = LAMBDA_TT(x, y){
return x >= (T)0.f? y : T(0.f); const T alphaT = static_cast<T>(alpha);
auto functor = LAMBDA_TT(x, y, alphaT) {
return x < 0 ? alphaT * y : y;
}; };
input->applyPairwiseLambda<T>(epsilon, functor, output); input->applyPairwiseLambda<T>(epsilon, functor, output);
} }
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
} }
template <typename T> template <typename T>
static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
auto functor = LAMBDA_TT(x, y){
return y * nd4j::math::nd4j_eluderivative<T,T>(x); const T alphaT = static_cast<T>(alpha);
auto functor = LAMBDA_TT(x, y, alphaT){
return y * nd4j::math::nd4j_eluderivative<T,T>(x, alphaT);
}; };
input->applyPairwiseLambda<T>(epsilon, functor, output); input->applyPairwiseLambda<T>(epsilon, functor, output);
} }
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
} }
template <typename T> template <typename T>

View File

@ -66,29 +66,35 @@ namespace nd4j {
} }
template <typename T> template <typename T>
linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
auto functor = LAMBDA_TT(x, y){
return x >= (T)0.f? y : T(0.f); const T alphaT = static_cast<T>(alpha);
auto functor = LAMBDA_TT(x, y, alphaT) {
return x < 0 ? alphaT * y : y;
}; };
input->applyPairwiseLambda(epsilon, functor, output); input->applyPairwiseLambda(epsilon, functor, output);
} }
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
} }
template <typename T> template <typename T>
linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
auto functor = LAMBDA_TT(x, y){
return y * nd4j::math::nd4j_eluderivative<T,T>(x); const T alphaT = static_cast<T>(alpha);
auto functor = LAMBDA_TT(x, y, alphaT){
return y * nd4j::math::nd4j_eluderivative<T,T>(x, alphaT);
}; };
input->applyPairwiseLambda(epsilon, functor, output); input->applyPairwiseLambda(epsilon, functor, output);
} }
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
} }
template <typename T> template <typename T>

View File

@ -46,8 +46,8 @@ namespace helpers {
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond);
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha);
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha);
void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);

View File

@ -2271,26 +2271,26 @@ namespace simdOps {
} }
}; };
template <typename X> template <typename X, typename Y, typename Z>
class ELU { class ELU {
public: public:
no_op_exec_special_same no_op_exec_special_same
no_op_exec_special_same_cuda no_op_exec_special_same_cuda
op_def static X op(X d1, X *params) { op_def static Z op(X d1, Y d2, Z *params) {
return nd4j::math::nd4j_elu<X,X>(d1); return nd4j::math::nd4j_elu<X,Z>(d1, static_cast<X>(d2));
} }
}; };
template <typename X> template <typename X, typename Y, typename Z>
class ELUDerivative { class ELUDerivative {
public: public:
no_op_exec_special_same no_op_exec_special_same
no_op_exec_special_same_cuda no_op_exec_special_same_cuda
op_def static X op(X d1, X *params) { op_def static Z op(X d1, Y d2, Z *params) {
return nd4j::math::nd4j_eluderivative<X,X>(d1); return nd4j::math::nd4j_eluderivative<X,Z>(d1, static_cast<X>(d2));
} }
}; };

View File

@ -130,13 +130,12 @@ namespace nd4j {
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_elu(T val) { math_def inline Z nd4j_elu(T val, T alpha) {
if (val >= (T) 0.f) return val; if (val >= (T) 0.f)
else return nd4j_exp<T, Z>(val) - (Z) 1.0f; return val;
//return val >= 0.0 ? val : (nd4j_exp<T>(val) - 1.0); return static_cast<Z>(alpha) * (nd4j_exp<T, Z>(val) - static_cast<Z>(1.0f));
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_leakyrelu(T val,T alpha) { math_def inline Z nd4j_leakyrelu(T val,T alpha) {
if (val < (T) 0.0f) if (val < (T) 0.0f)
@ -145,13 +144,14 @@ namespace nd4j {
return val; return val;
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_eluderivative(T val) { math_def inline Z nd4j_eluderivative(T val, T alpha) {
if (val >= (T) 0.0f) return (Z) 1.0f; if (val >= static_cast<T>(0.0f))
else return nd4j_exp<T, Z>(val); return static_cast<Z>(1.0f);
return static_cast<Z>(alpha) * nd4j_exp<T, Z>(val);
//return val >= 0.0 ? 1.0 : nd4j_exp(val); //return val >= 0.0 ? 1.0 : nd4j_exp(val);
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_sin(T val); math_def inline Z nd4j_sin(T val);

View File

@ -2794,53 +2794,42 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
TEST_F(DeclarableOpsTests3, elu_test1) { TEST_F(DeclarableOpsTests3, elu_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9});
// auto expS = NDArrayFactory::create<double>('c', {3}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});
// auto expU = NDArrayFactory::create<double>('c', {3,3});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, -0.32968, -0.393469, -0.451188, .7, .8, .9});
nd4j::ops::elu op; nd4j::ops::elu op;
auto results = op.execute({&x}, {}, {}); auto results = op.execute({&x}, {0.5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1);
// auto v = results->at(2);
// s->printIndexedBuffer("ELU");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, elu_test2) { TEST_F(DeclarableOpsTests3, elu_bp_test1) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9});
auto eps = NDArrayFactory::create<double>('c', {3,3}); auto eps = NDArrayFactory::create<double>('c', {3,3});
eps.assign(2.); eps.assign(2.);
// auto expU = NDArrayFactory::create<double>('c', {3,3}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 1.34064, 1.213061, 1.097623, 2, 2, 2});
nd4j::ops::elu_bp op; nd4j::ops::elu_bp op;
auto results = op.execute({ &x, &eps }, {}, {}); auto results = op.execute({ &x, &eps }, {0.5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1); ASSERT_TRUE(exp.equalsTo(s));
// auto v = results->at(2);
// s->printIndexedBuffer("ELU_BP");
ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, lrelu_test1) { TEST_F(DeclarableOpsTests3, lrelu_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
// auto expS = NDArrayFactory::create<double>('c', {3});
// auto expU = NDArrayFactory::create<double>('c', {3,3});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});
nd4j::ops::lrelu op; nd4j::ops::lrelu op;
@ -2849,20 +2838,16 @@ TEST_F(DeclarableOpsTests3, lrelu_test1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1);
// auto v = results->at(2);
// s->printIndexedBuffer("LRELU");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
} }
TEST_F(DeclarableOpsTests3, lrelu_test2) { TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
// auto expS = NDArrayFactory::create<double>('c', {3});
auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2, 2,2});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0, 0, 0, 2, 2, 2}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
nd4j::ops::lrelu_bp op; nd4j::ops::lrelu_bp op;
auto results = op.execute({&x, &eps}, {0.2}, {}); auto results = op.execute({&x, &eps}, {0.2}, {});
@ -2870,9 +2855,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1);
// auto v = results->at(2);
// s->printIndexedBuffer("LRELU_BP");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
@ -2882,8 +2864,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) {
TEST_F(DeclarableOpsTests3, selu_test1) { TEST_F(DeclarableOpsTests3, selu_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
// auto expS = NDArrayFactory::create<double>('c', {3});
// auto expU = NDArrayFactory::create<double>('c', {3,3});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});
nd4j::ops::selu op; nd4j::ops::selu op;
@ -2892,7 +2872,6 @@ TEST_F(DeclarableOpsTests3, selu_test1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// s->printIndexedBuffer("SELU");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;

View File

@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) {
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.});
auto res = NDArrayFactory::create<double>('c', {2, 2, 2}); auto res = NDArrayFactory::create<double>('c', {2, 2, 2});
input.applyTransform(transform::ELU, &res); input.applyScalar(nd4j::scalar::ELU, 1.f, &res);
ASSERT_TRUE(res.equalsTo(&exp)); ASSERT_TRUE(res.equalsTo(&exp));
} }