From bdc3eacafde4026371b7097fc04a537f46662d73 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 30 Aug 2019 20:13:01 +0300 Subject: [PATCH] one small playground test Signed-off-by: raver119 --- .../declarable/helpers/cpu/legacy_helper.cpp | 21 +++++++++++++-- .../layers_tests/PlaygroundTests.cpp | 26 +++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index 45024b5cb..d673e64bd 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -39,11 +39,28 @@ namespace helpers { template static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f ? y : T(0.f); + + T zero = (T) 0.f; + auto functor = LAMBDA_TT(x, y, zero){ + return x > zero ? y : zero; }; input->applyPairwiseLambda(epsilon, functor, output); + + /* + auto x = input->bufferAsT(); + auto y = epsilon->bufferAsT(); + auto z = output->bufferAsT(); + + int length = input->lengthOf(); + + T zero = (T) 0.f; + + PRAGMA_OMP_PARALLEL_FOR + for (int e = 0; e < length; e++) { + z[e] = x[e] > zero ? y[e] : zero; + } + */ } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 2d9a23e59..b76538afd 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -41,6 +42,8 @@ #include #include +#include + using namespace nd4j; using namespace nd4j::graph; @@ -55,3 +58,26 @@ public: } }; +/* +TEST_F(PlaygroundTests, test_relubp_1) { + auto x = NDArrayFactory::create('c', {128, 64, 224, 224}); + auto y = x.ulike(); + auto z = x.ulike(); + RandomGenerator rng(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, -1.0, 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &y, -1.0, 1.0); + + int iterations = 10; + + auto timeStart = std::chrono::system_clock::now(); + for (int e = 0; e < iterations; e++) + ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, &z); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + auto time = (Nd4jLong) outerTime / iterations; + auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / 1024 / 1024; + + nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw); +} +*/