diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h
index 4b1f3448f..0e5200321 100644
--- a/libnd4j/include/loops/legacy_ops.h
+++ b/libnd4j/include/loops/legacy_ops.h
@@ -323,7 +323,9 @@
         (11, TruncatedNormalDistribution) ,\
         (12, AlphaDropOut),\
         (13, ExponentialDistribution),\
-        (14, ExponentialDistributionInv)
+        (14, ExponentialDistributionInv), \
+        (15, PoissonDistribution), \
+        (16, GammaDistribution)
 
 #define PAIRWISE_INT_OPS \
         (0, ShiftLeft), \
diff --git a/libnd4j/include/ops/declarable/generic/random/gamma.cpp b/libnd4j/include/ops/declarable/generic/random/gamma.cpp
new file mode 100644
index 000000000..672eba422
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/random/gamma.cpp
@@ -0,0 +1,83 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author George A. Shulinok
+//
+
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_random_gamma)
+
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/random.h>
+
+namespace nd4j {
+    namespace ops {
+        CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) {
+            // gamma distribution
+            auto rng = block.randomGenerator();
+            auto shape = INPUT_VARIABLE(0);
+            auto alpha = INPUT_VARIABLE(1);
+            NDArray* beta = nullptr;
+
+            if (block.width() > 2) {
+                beta = INPUT_VARIABLE(2);
+                REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, "random_gamma: alpha and beta shapes should be broadcastable.");
+            }
+
+            auto output = OUTPUT_VARIABLE(0);
+            auto seed = 0;
+
+            if (block.getIArguments()->size()) {
+                seed = INT_ARG(0);
+            }
+
+            rng.setSeed(seed);
+
+            helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output);
+
+            return Status::OK();
+        }
+
+        DECLARE_SHAPE_FN(random_gamma) {
+            auto in = INPUT_VARIABLE(0);
+            auto shape = in->template asVectorT<Nd4jLong>();
+            auto alphaShape = inputShape->at(1);
+            auto additionalShape = alphaShape;
+            if (inputShape->size() > 2) {
+                auto rest = inputShape->at(2); additionalShape = nullptr;
+                REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, "random_gamma: alpha and beta shapes should be broadcastable.");
+                ShapeUtils::evalBroadcastShapeInfo(alphaShape, rest, true, additionalShape, block.workspace());
+            }
+            auto lastDim = shape::sizeAt(alphaShape, 0);
+            auto dtype = ArrayOptions::dataType(alphaShape);
+            for (auto i = 0; i < shape::rank(additionalShape); i++)
+                shape.push_back(shape::sizeAt(additionalShape, i));
+            auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
+            return SHAPELIST(newShape);
+        }
+
+        DECLARE_TYPES(random_gamma) {
+            getOpDescriptor()
+                    ->setAllowedInputTypes(0, {ALL_INTS})
+                    ->setAllowedInputTypes(1, {ALL_FLOATS})
+                    ->setAllowedInputTypes(2, {ALL_FLOATS})
+                    ->setAllowedOutputTypes({ALL_FLOATS});
+        }
+    }
+}
+
+#endif
\ No newline at end of file
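For context, a standalone usage sketch of the new op, modeled on the RNGTests added at the end of this patch (illustrative only, not part of the patch; the seed IArg is optional, and random_poisson below follows the same pattern with a single lambda array in place of alpha/beta):

    #include <ops/declarable/CustomOperations.h>

    using namespace nd4j;

    void sampleRandomGamma() {
        // ten draws per alpha/beta pair -> output shape {10, 2, 3}
        auto shape = NDArrayFactory::create<int>('c', {1}, {10});
        auto alpha = NDArrayFactory::create<float>('c', {2, 3});
        auto beta  = NDArrayFactory::create<float>('c', {2, 3});
        alpha.linspace(1.0);
        beta.assign(1.0);

        nd4j::ops::random_gamma op;
        auto result = op.execute({&shape, &alpha, &beta}, {}, {119});   // single IArg: RNG seed
        auto z = result->at(0);                                         // NDArray of shape {10, 2, 3}
        // ... consume z ...
        delete result;
    }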
diff --git a/libnd4j/include/ops/declarable/generic/random/poisson.cpp b/libnd4j/include/ops/declarable/generic/random/poisson.cpp
new file mode 100644
index 000000000..935bed095
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/random/poisson.cpp
@@ -0,0 +1,67 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author George A. Shulinok
+//
+
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_random_poisson)
+
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/random.h>
+
+namespace nd4j {
+    namespace ops {
+        CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) {
+            // poisson distribution
+            auto rng = block.randomGenerator();
+            auto shape = INPUT_VARIABLE(0);
+            auto lambda = INPUT_VARIABLE(1);
+            auto output = OUTPUT_VARIABLE(0);
+            auto seed = 0;
+            if (block.getIArguments()->size()) {
+                seed = INT_ARG(0);
+            }
+            rng.setSeed(seed);
+            helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output);
+
+            return Status::OK();
+        }
+
+
+        DECLARE_SHAPE_FN(random_poisson) {
+            auto in = INPUT_VARIABLE(0);
+            auto shape = in->template asVectorT<Nd4jLong>();
+            auto lambdaShape = inputShape->at(1);
+            auto dtype = ArrayOptions::dataType(lambdaShape);
+            for (auto d = 0; d < shape::rank(lambdaShape); ++d ) {
+                shape.emplace_back(shape::sizeAt(lambdaShape, d));
+            }
+            auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
+            return SHAPELIST(newShape);
+        }
+
+        DECLARE_TYPES(random_poisson) {
+            getOpDescriptor()
+                    ->setAllowedInputTypes(0, {ALL_INTS})
+                    ->setAllowedInputTypes(1, {ALL_FLOATS})
+                    ->setAllowedOutputTypes({ALL_FLOATS});
+        }
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/random.h b/libnd4j/include/ops/declarable/headers/random.h
index 224db2dae..333fcc089 100644
--- a/libnd4j/include/ops/declarable/headers/random.h
+++ b/libnd4j/include/ops/declarable/headers/random.h
@@ -49,7 +49,23 @@ namespace nd4j {
         DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0);
         #endif
 
+        #if NOT_EXCLUDED(OP_random_crop)
         DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0);
+        #endif
+
+        /**
+         * random_gamma op.
+         */
+        #if NOT_EXCLUDED(OP_random_gamma)
+        DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0);
+        #endif
+
+        /**
+         * random_poisson op.
+         */
+        #if NOT_EXCLUDED(OP_random_poisson)
+        DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
+        #endif
 
     }
 }
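Both shape functions above build the output shape the same way: the values of the integer shape input give the leading dimensions, and the shape of the parameter array (lambda, or broadcast(alpha, beta)) is appended after them. Roughly, in plain C++ outside the NDArray machinery (illustration only):

    #include <cstdint>
    #include <vector>

    // e.g. shapeInput = {10}, paramShape = {3, 2}  ->  output shape {10, 3, 2}
    std::vector<int64_t> buildOutputShape(const std::vector<int64_t>& shapeInput,
                                          const std::vector<int64_t>& paramShape) {
        std::vector<int64_t> out(shapeInput);
        out.insert(out.end(), paramShape.begin(), paramShape.end());
        return out;
    }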
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp
new file mode 100644
index 000000000..5bbf618ef
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp
@@ -0,0 +1,132 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+
+#include <ops/declarable/helpers/random.h>
+//#include
+#include
+//#include
+#include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    template <typename T>
+    void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
+
+        Nd4jLong* broadcasted = nullptr;
+        if (beta != nullptr)
+            ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace());
+        else
+            broadcasted = alpha->shapeInfo();
+        auto step = shape::length(broadcasted);
+        auto shift = output->lengthOf() / step;
+
+        auto copyAlpha = alpha;
+        auto copyBeta = beta;
+        if (beta != nullptr) {
+            NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context);
+            NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context);
+
+            copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha));
+            copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta));
+
+        }
+//        bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c';
+        bool directOutput = output->ews() == 1 && output->ordering() == 'c';
+        T* outputBuf = output->dataBuffer()->primaryAsT<T>();
+
+        PRAGMA_OMP_PARALLEL_FOR
+        for (auto k = 0; k < shift; k++) {
+            auto pos = k * step;
+            auto u = rng.relativeT<T>(k, 0., 1.);
+            for (auto e = 0; e < step; e++)
+                if (directOutput) {
+                    outputBuf[pos + e] = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
+                            beta != nullptr ? copyBeta->t<T>(e) * u : u);
+                }
+                else {
+                    output->t<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
+                            beta != nullptr ? copyBeta->t<T>(e) * u : u);
+                }
+        }
+
+        if (beta != nullptr) {
+            delete copyAlpha;
+            delete copyBeta;
+            //delete broadcasted;
+        }
+    }
+
+    void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
+    }
+    BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context,
+            graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
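In scalar terms, the helper above draws one uniform value u per output slice and maps every broadcast alpha/beta element to P(alpha, beta * u), the regularized lower incomplete gamma function. A rough standalone illustration of that mapping, not the library code, using the same series expansion nd4j_igamma in templatemath.h is built on:

    #include <cmath>

    // Truncated series for P(a, x) = x^a * e^{-x} / Gamma(a) * sum_{i>=0} x^i / (a * (a+1) * ... * (a+i)).
    static double regularizedLowerIncompleteGamma(double a, double x) {
        double front = std::pow(x, a) / (std::exp(x) * std::tgamma(a));
        double sum = 0.0, denom = 1.0;
        for (int i = 0; i < 200; ++i) {      // fixed cut-off instead of the 1e-12 stop rule used in templatemath.h
            denom *= (a + i);
            sum += std::pow(x, i) / denom;
        }
        return front * sum;
    }

    // One output element, given broadcast alpha/beta values and a uniform draw u in (0, 1).
    static double gammaFillElement(double alpha, double beta, double u) {
        return regularizedLowerIncompleteGamma(alpha, beta * u);
    }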
+
+    /*
+     * algorithm: Poisson generator based upon inversion by sequential search
+        init:
+             Let x ← 0, p ← e−λ, s ← p.
+             Generate uniform random number u in [0,1].
+        while u > s do:
+             x ← x + 1.
+             p ← p * λ / x.
+             s ← s + p.
+        return x.
+     * */
+    template <typename T>
+    void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
+        auto shift = output->lengthOf() / lambda->lengthOf();
+        auto step = lambda->lengthOf();
+        T* lambdaBuf = lambda->dataBuffer()->primaryAsT<T>();
+        T* outputBuf = output->dataBuffer()->primaryAsT<T>();
+        bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c';
+        bool directOut = output->ews() == 1 && output->ordering() == 'c';
+        PRAGMA_OMP_PARALLEL_FOR
+        for (auto k = 0; k < shift; k++) {
+            auto pos = k * step;
+            auto u = rng.relativeT<T>(k, 0., 1.);
+            for (auto e = 0; e < step; e++) {
+                auto p = math::nd4j_exp<T, T>(-lambda->t<T>(e));
+                auto s = p;
+                auto x = T(0.f);
+                while (u > s) {
+                    x += 1.f;
+                    p *= directLa ? lambdaBuf[e] / x : lambda->t<T>(e) / x;
+                    s += p;
+                }
+                if (directOut)
+                    outputBuf[pos + e] = x;
+                else
+                    output->t<T>(pos + e) = x;
+            }
+        }
+    }
+
+    void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
+        BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
+    }
+    BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
+            graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
+
+}
+}
+}
\ No newline at end of file
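For reference, the inversion-by-sequential-search generator quoted in the comment above, written as a plain standalone function (illustration only; the helper drives the same loop with nd4j's RandomGenerator rather than <random>):

    #include <cmath>
    #include <random>

    // Draw one Poisson(lambda) sample: walk the CDF until it passes a uniform draw u.
    int poissonSample(double lambda, std::mt19937_64& gen) {
        std::uniform_real_distribution<double> uniform(0.0, 1.0);
        const double u = uniform(gen);
        int x = 0;
        double p = std::exp(-lambda);   // P(X = 0)
        double s = p;                   // running CDF, P(X <= x)
        while (u > s) {
            ++x;
            p *= lambda / x;            // P(X = x) from P(X = x - 1)
            s += p;
        }
        return x;
    }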
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu
new file mode 100644
index 000000000..e1f8645b8
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu
@@ -0,0 +1,186 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+
+#include <ops/declarable/helpers/random.h>
+//#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    /*
+     * fillGammaKernel - fill up output with gamma distributed values
+     *
+     *  uList - uniformly distributed values set
+     *  uLength - length of uList
+     *  alpha - alpha param
+     *  beta - beta param
+     *  output - distributed output.
+     * */
+    template <typename T>
+    static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, Nd4jLong* alphaShape,
+            T* beta, Nd4jLong* betaShape, T* output, Nd4jLong* outputShape) {
+        // fill up
+        __shared__ Nd4jLong aLength;
+        if (threadIdx.x == 0) {
+            aLength = shape::length(alphaShape);
+        }
+        __syncthreads();
+
+        for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
+            auto pos = k * aLength;
+            auto u = uList[k]; // single uniform value for the k-th output slice
+            for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) {
+                auto aIndex = shape::getIndexOffset(e, alphaShape);
+                auto bIndex = betaShape ? shape::getIndexOffset(e, betaShape) : -1LL;
+                auto betaV = T(beta != nullptr ? beta[bIndex] * u : u);
+                auto zIndex = shape::getIndexOffset(e + pos, outputShape);
+
+                output[zIndex] = math::nd4j_igamma<T, T, T>(alpha[aIndex], betaV);
+            }
+        }
+    }
+
+    template <typename T>
+    static void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
+        // To fill the output, alpha and beta first need to be broadcast to a common shape.
+        Nd4jLong* broadcasted = nullptr;
+        if (beta != nullptr)
+            ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace());
+        else
+            broadcasted = alpha->shapeInfo();
+        auto step = shape::length(broadcasted);
+        auto shift = output->lengthOf() / step;
+
+        auto copyAlpha = alpha;
+        auto copyBeta = beta;
+        if (beta != nullptr) {
+            NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context);
+            NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context);
+
+            copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha));
+            copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta));
+            copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice();
+        }
+
+        auto stream = context->getCudaStream();
+        NDArray uniform = NDArrayFactory::create<T>('c', {shift}, context);
+        uniform.syncToDevice();
+        // fill up uniform with given length
+        RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
+
+        fillGammaKernel<<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), shift,
+                copyAlpha->dataBuffer()->specialAsT<T>(), copyAlpha->specialShapeInfo(),
+                beta ? copyBeta->dataBuffer()->specialAsT<T>() : (T*)nullptr,
+                beta ? copyBeta->specialShapeInfo() : (Nd4jLong*)nullptr,
+                output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
+
+        if (beta != nullptr) {
+            delete copyAlpha;
+            delete copyBeta;
+            //delete broadcasted;
+        }
+
+    }
+
+    void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
+        if (beta)
+            NDArray::prepareSpecialUse({output}, {alpha, beta});
+        else
+            NDArray::prepareSpecialUse({output}, {alpha});
+        BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
+        if (beta)
+            NDArray::registerSpecialUse({output}, {alpha, beta});
+        else
+            NDArray::prepareSpecialUse({output}, {alpha});
+    }
+    BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
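The kernels in this file share one decomposition: the grid walks the uniform draws, one per output slice, while each block's threads walk the parameter elements. Stripped of the shape-info machinery, the pattern looks roughly like this (placeholder transform, illustration only):

    // Grid-stride over slices, block-stride over parameter elements.
    template <typename T>
    static __global__ void sliceFillSketch(const T* uList, int uLength,
                                           const T* param, int paramLength, T* output) {
        for (int k = blockIdx.x; k < uLength; k += gridDim.x) {          // one output slice per block
            int pos = k * paramLength;
            T u = uList[k];                                              // the slice's single uniform draw
            for (int e = threadIdx.x; e < paramLength; e += blockDim.x)
                output[pos + e] = param[e] * u;                          // stand-in for the igamma / Poisson math
        }
    }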
+
+
+    /*
+     * algorithm Poisson generator based upon inversion by sequential search
+     *
+    init:
+         Let x ← 0, p ← e−λ, s ← p.
+         Draw u from a uniformly distributed random sequence U on [0, 1].
+    while u > s do:
+         x ← x + 1.
+         p ← p * λ / x.
+         s ← s + p.
+    return x.
+     * */
+    template <typename T>
+    static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, Nd4jLong* lambdaShape, T* output,
+            Nd4jLong* outputShape) {
+
+        __shared__ Nd4jLong step;
+
+        if (threadIdx.x == 0) {
+            step = shape::length(lambdaShape);
+        }
+        __syncthreads();
+
+        for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
+            auto pos = k * step;
+            auto u = uList[k];
+            for (auto e = threadIdx.x; e < step; e += blockDim.x) {
+                auto p = math::nd4j_exp<T, T>(-lambda[e]);
+                auto s = p;
+                auto x = T(0.f);
+                auto lIndex = shape::getIndexOffset(e, lambdaShape);
+                auto zIndex = shape::getIndexOffset(e + pos, outputShape);
+                while (u > s) {
+                    x += T(1.);
+                    p *= lambda[lIndex] / x;
+                    s += p;
+                }
+                output[zIndex] = x;
+            }
+        }
+    }
+
+    template <typename T>
+    static void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
+        auto shift = output->lengthOf() / lambda->lengthOf();
+        NDArray uniform('c', {shift}, output->dataType());
+        auto stream = context->getCudaStream();
+        // fill up uniform with given length
+        RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
+        fillPoissonKernel<<<128, 256, 128, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), uniform.lengthOf(),
+                lambda->dataBuffer()->specialAsT<T>(), lambda->specialShapeInfo(),
+                output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
+    }
+
+    void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
+        NDArray::prepareSpecialUse({output}, {lambda});
+        BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
+        NDArray::registerSpecialUse({output}, {lambda});
+    }
+
+    BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
+}
+}
+}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/helpers/random.h b/libnd4j/include/ops/declarable/helpers/random.h
new file mode 100644
index 000000000..a4603c0bd
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/random.h
@@ -0,0 +1,40 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+//
+//  Declaration of distribution helpers
+//
+#ifndef __RANDOM_HELPERS__
+#define __RANDOM_HELPERS__
+#include
+#include
+#include
+#include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
+    void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
+
+}
+}
+}
+#endif
diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h
index 1d5e89792..8eb25c84c 100644
--- a/libnd4j/include/ops/random_ops.h
+++ b/libnd4j/include/ops/random_ops.h
@@ -129,6 +129,47 @@
         }
     };
 
+    template <typename T>
+    class PoissonDistribution {
+    public:
+        no_exec_special
+        no_exec_special_cuda
+
+        method_XY
+
+        random_def T op(Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
+            T lambda = extraParams[0];
+            T x = helper->relativeT<T>(idx, -nd4j::DataTypeUtils::template max<T>() / 10, nd4j::DataTypeUtils::template max<T>() / 10);
+            return x <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igammac<T, T, T>(nd4j::math::nd4j_floor<T, T>(x), lambda);
+        }
+
+        random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
+            T lambda = extraParams[0];
+            return valueX <= (T)0.f ? (T)0.f : (T)nd4j::math::nd4j_igammac<T, T, T>(nd4j::math::nd4j_floor<T, T>(valueX), lambda);
+        }
+    };
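Both legacy functors lean on the standard link between these distributions and the regularized incomplete gamma functions; for reference (textbook identities, not taken from the patch):

    P(X <= k) = Q(k + 1, lambda)                   for X ~ Poisson(lambda)
    F(x; alpha, rate beta) = P(alpha, beta * x)    for the gamma distribution

where P and Q = 1 - P are the regularized lower and upper incomplete gamma functions, i.e. nd4j_igamma and nd4j_igammac above. The GammaDistribution functor below evaluates the second form at a pseudo-random argument; the PoissonDistribution functor evaluates Q at floor(x).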
+
+    template <typename T>
+    class GammaDistribution {
+    public:
+        no_exec_special
+        no_exec_special_cuda
+
+        method_XY
+
+        random_def T op(Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
+            T alpha = extraParams[0];
+            T beta = extraParams[1];
+            T x = helper->relativeT<T>(idx, -nd4j::DataTypeUtils::template max<T>() / 10, nd4j::DataTypeUtils::template max<T>() / 10);
+            return x <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igamma<T, T, T>(alpha, x * beta);
+        }
+
+        random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
+            T alpha = extraParams[0];
+            T beta = extraParams[1];
+            return valueX <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igamma<T, T, T>(alpha, beta * valueX);
+        }
+    };
 
     /**
      * Basic DropOut/DropConnect Op
diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h
index d0af6c8ed..f40591e17 100644
--- a/libnd4j/include/templatemath.h
+++ b/libnd4j/include/templatemath.h
@@ -894,6 +894,10 @@
 		Z aim = nd4j_pow(x, a) / (nd4j_exp(x) * nd4j_gamma(a));
 		auto sum = Z(0.);
 		auto denom = Z(1.);
+		if (a <= X(0.000001))
+			//throw std::runtime_error("Cannot calculate gamma for a zero val.");
+			return Z(0);
+
 		for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
 			denom *= (a + i);
 			sum += nd4j_pow(x, i) / denom;
diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp
index e1a23ee3f..29c1d5214 100644
--- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp
@@ -773,6 +773,88 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
     delete result;
 }
 
+TEST_F(RNGTests, Test_PoissonDistribution_1) {
+    auto x = NDArrayFactory::create<int>('c', {1}, {10});
+    auto la = NDArrayFactory::create<float>('c', {2, 3});
+    auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
+
+    la.linspace(1.0);
+
+
+    nd4j::ops::random_poisson op;
+    auto result = op.execute({&x, &la}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+//    z->printIndexedBuffer("Poisson distribution");
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    delete result;
+}
+
+TEST_F(RNGTests, Test_GammaDistribution_1) {
+    auto x = NDArrayFactory::create<int>('c', {1}, {10});
+    auto al = NDArrayFactory::create<float>('c', {2, 3});
+    auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
+
+    al.linspace(1.0);
+
+
+    nd4j::ops::random_gamma op;
+    auto result = op.execute({&x, &al}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+//    z->printIndexedBuffer("Gamma distribution");
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    delete result;
+}
+
+TEST_F(RNGTests, Test_GammaDistribution_2) {
+    auto x = NDArrayFactory::create<int>('c', {1}, {10});
+    auto al = NDArrayFactory::create<float>('c', {2, 3});
+    auto be = NDArrayFactory::create<float>('c', {2, 3});
+    auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
+
+    al.linspace(1.0);
+    be.assign(1.0);
+
+    nd4j::ops::random_gamma op;
+    auto result = op.execute({&x, &al, &be}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+//    z->printIndexedBuffer("Gamma distribution");
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    delete result;
+}
+
+TEST_F(RNGTests, Test_GammaDistribution_3) {
+    auto x = NDArrayFactory::create<int>('c', {1}, {10});
+    auto al = NDArrayFactory::create<float>('c', {3, 1});
+    auto be = NDArrayFactory::create<float>('c', {1, 2});
+    auto exp0 = NDArrayFactory::create<float>('c', {10, 3, 2});
+
+    al.linspace(1.0);
+    be.assign(2.0);
+
+    nd4j::ops::random_gamma op;
+    auto result = op.execute({&x, &al, &be}, {}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+//    z->printIndexedBuffer("Gamma distribution");
+    ASSERT_TRUE(exp0.isSameShape(z));
+    ASSERT_FALSE(exp0.equalsTo(z));
+
+    delete result;
+}
+
 namespace nd4j {
 namespace tests {
     static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {