one small playground test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-30 20:13:01 +03:00
parent 6efffb727f
commit bdc3eacafd
2 changed files with 45 additions and 2 deletions

View File

@ -39,11 +39,28 @@ namespace helpers {
template <typename T>
static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
auto functor = LAMBDA_TT(x, y){
return x > (T)0.f ? y : T(0.f);
T zero = (T) 0.f;
auto functor = LAMBDA_TT(x, y, zero){
return x > zero ? y : zero;
};
input->applyPairwiseLambda<T>(epsilon, functor, output);
/*
auto x = input->bufferAsT<T>();
auto y = epsilon->bufferAsT<T>();
auto z = output->bufferAsT<T>();
int length = input->lengthOf();
T zero = (T) 0.f;
PRAGMA_OMP_PARALLEL_FOR
for (int e = 0; e < length; e++) {
z[e] = x[e] > zero ? y[e] : zero;
}
*/
}
void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {

View File

@ -32,6 +32,7 @@
#include <GradCheck.h>
#include <ops/declarable/helpers/im2col.h>
#include <Loops.h>
#include <RandomLauncher.h>
#include <helpers/BenchmarkHelper.h>
#include <ops/declarable/helpers/scatter.h>
@ -41,6 +42,8 @@
#include <performance/benchmarking/FullBenchmarkSuit.h>
#include <performance/benchmarking/LightBenchmarkSuit.h>
#include <ops/declarable/helpers/legacy_helpers.h>
using namespace nd4j;
using namespace nd4j::graph;
@ -55,3 +58,26 @@ public:
}
};
/*
TEST_F(PlaygroundTests, test_relubp_1) {
auto x = NDArrayFactory::create<float>('c', {128, 64, 224, 224});
auto y = x.ulike();
auto z = x.ulike();
RandomGenerator rng(119, 120);
RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, -1.0, 1.0);
RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &y, -1.0, 1.0);
int iterations = 10;
auto timeStart = std::chrono::system_clock::now();
for (int e = 0; e < iterations; e++)
ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, &z);
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
auto time = (Nd4jLong) outerTime / iterations;
auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / 1024 / 1024;
nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw);
}
*/