diff --git a/libnd4j/include/ops/declarable/generic/random/multinomial.cpp b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp
new file mode 100644
index 000000000..bbdee17f4
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/random/multinomial.cpp
@@ -0,0 +1,114 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Oleh Semeniv (oleg.semeniv@gmail.com)
+//
+
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_random_multinomial)
+
+#include <ops/declarable/CustomOperations.h>
+#include <helpers/RandomLauncher.h>
+#include <ops/declarable/helpers/random.h>
+
+namespace nd4j {
+    namespace ops {
+        ///////////////////////
+        /**
+         * multinomial (categorical) random generator.
+         * Takes a 2D ndarray of logits with shape [batch_size (N), num_classes (K)],
+         * which holds the unnormalized log-probabilities for all K classes,
+         * and an array with a single scalar value: the number of independent samples to draw for each of the N experiments.
+         * Int arguments: 0 - optional argument, the dimension that corresponds to batch_size
+         * Int arguments: 1 - optional argument, integer type to use for the output. Default int64.
+         */
+        // reference: https://en.wikipedia.org/wiki/Categorical_distribution
+        // method: Gumbel trick + argmax over the perturbed logits
+        CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) {
+
+            auto input = INPUT_VARIABLE(0);
+            auto output = OUTPUT_VARIABLE(0);
+            auto inputSamples = INPUT_VARIABLE(1);
+
+            REQUIRE_TRUE(!input->isEmpty(), 0, "RANDOM_MULTINOMIAL OP: Logits array must not be empty.");
+
+            REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: A single scalar with the number of samples"
+                " has to be provided, but got no such argument instead.");
+
+            Nd4jLong numOfSamples = static_cast<Nd4jLong>(inputSamples->e<int>(0));
+            // do nothing if the number of samples is 0
+            if (0 == numOfSamples)
+                return Status::OK();
+
+            REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater than 0, got %i. ", numOfSamples);
+
+            const int rank = input->rankOf();
+            REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank);
+
+            const int argSize = block.getIArguments()->size();
+            const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
+
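+            // dimA is the axis that holds the classes, i.e. the axis other than the batch axis dimC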
+            auto dimA = (0 == dimC) ? 1 : 0;
+            if (1 == input->sizeAt(dimA)) {
+                *output = 0;
+                return Status::OK();
+            }
+
+            auto rng = block.randomGenerator();
+            helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, numOfSamples, dimC);
+            return Status::OK();
+        }
+
+
+        DECLARE_SHAPE_FN(random_multinomial) {
+
+            auto input = INPUT_VARIABLE(0);
+            auto inputSamples = INPUT_VARIABLE(1);
+
+            REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: A single scalar with the number of samples"
+                " has to be provided, but got no such argument instead.");
+
+            Nd4jLong numOfSamples = static_cast<Nd4jLong>(inputSamples->e<int>(0));
+
+            REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater than 0, got %i. ", numOfSamples);
+
+            const int rank = input->rankOf();
+            REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank);
+
+            const int argSize = block.getIArguments()->size();
+            const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
+
+            auto nShape = input->getShapeAsVector();
+            auto dimA = (0 == dimC) ? 1 : 0;
+            nShape[dimA] = numOfSamples;
+
+            DataType nType = (argSize > 1) ? (INT_ARG(1) >= 0 ? static_cast<DataType>(INT_ARG(1)) : nd4j::DataType::INT64) : nd4j::DataType::INT64;
+            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(nType, input->ordering(), nShape));
+        }
+
+        DECLARE_TYPES(random_multinomial) {
+            getOpDescriptor()
+                ->setAllowedInputTypes(0, { ALL_FLOATS, ALL_INTS })
+                ->setAllowedInputTypes(1, { nd4j::DataType::INT32 })
+                ->setAllowedOutputTypes(0, { ALL_INDICES });
+        }
+    }
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/random.h b/libnd4j/include/ops/declarable/headers/random.h
index a361c8fde..f52534411 100644
--- a/libnd4j/include/ops/declarable/headers/random.h
+++ b/libnd4j/include/ops/declarable/headers/random.h
@@ -49,6 +49,22 @@ namespace nd4j {
 #if NOT_EXCLUDED(OP_randomuniform)
         DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
 #endif
+        /*
+         * multinomial (categorical) random generator draws samples from a multinomial distribution
+         *
+         * Input array:
+         *    0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)]
+         *    1 - array with a single int value, the number of independent samples to draw for each of the N experiments
+         * Int arguments:
+         *    0 - optional argument, the dimension that corresponds to batch_size
+         *    1 - optional argument, integer type to use for the output. Default int64.
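+         *        (for i-arg 0, a negative dimension index is counted from the end: INT_ARG(0) + rank)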
+         *
+         * Output array:
+         *    0 - 2D ndarray with the drawn samples, of shape [batch_size, num_samples]
+         */
+        #if NOT_EXCLUDED(OP_random_multinomial)
+        DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0);
+        #endif
 
 #if NOT_EXCLUDED(OP_random_normal)
         DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0);
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp
index f25859b1c..ad04db307 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/random.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/random.cpp
@@ -24,6 +24,8 @@
 //#include
 #include
 #include
+#include <execution/Threads.h>
+#include <DataTypeUtils.h>
 
 namespace nd4j {
 namespace ops {
@@ -150,6 +152,61 @@ namespace helpers {
     void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
         BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
     }
+
+    // reference: https://en.wikipedia.org/wiki/Categorical_distribution
+    // method: Gumbel trick + argmax over the perturbed logits
+    template <typename Tx, typename Tz>
+    void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
+
+        const Tx* x = input.bufferAsT<Tx>();
+        Tz* z = output.bufferAsT<Tz>();
+
+        Tx minVal = DataTypeUtils::min<Tx>();
+        Tx maxVal = 1.0;
+
+        auto dimA = (0 == dimC) ? 1 : 0;
+        const Nd4jLong batchValue = output.sizeAt(dimC);
+        const Nd4jLong numOfClassX = input.sizeAt(dimA);
+
+        const Nd4jLong zDimAstride = output.stridesOf()[dimA];
+        const Nd4jLong xDimAstride = input.stridesOf()[dimA];
+        const Nd4jLong zDimCstride = output.stridesOf()[dimC];
+        const Nd4jLong xDimCstride = input.stridesOf()[dimC];
+
+        auto func = PRAGMA_THREADS_FOR_2D {
+            for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) {
+                for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) {
+
+                    const Tx* xTad = x + (nBatchIndex * xDimCstride);
+                    Tz* zTad = z + (nBatchIndex * zDimCstride);
+                    Tz& arg = zTad[nSampleIndexInBatch * zDimAstride];
+                    Tx Max = -minVal;
+
+                    auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
+                    auto nClassesPerSample = nSampleIndexInBatch * numOfClassX;
+                    for (auto nClass = 0; nClass < numOfClassX; nClass += 1) {
+                        auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass;
+                        auto uniformLog = nd4j::math::nd4j_log<Tx, Tx>(-nd4j::math::nd4j_log<Tx, Tx>(rng.relativeT<Tx>(nIndex, minVal, maxVal)));
+                        Tx tValue = (xTad[nClass * xDimAstride] - uniformLog);
+                        if (tValue > Max) {
+                            Max = tValue;
+                            arg = nClass;
+                        }
+                    }
+                }
+            }
+        };
+
+        samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
+        rng.rewindH(output.lengthOf() * numOfClassX);
+
+        return;
+    }
+
+    void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
+        BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillRandomMultiNomial_, (context, rng, input, output, numOfSamples, dimC), FLOAT_TYPES, INDEXING_TYPES);
+    }
+
 }
 }
 }
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/random.cu b/libnd4j/include/ops/declarable/helpers/cuda/random.cu
index 7014d6a50..1e290bc56 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/random.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/random.cu
@@ -27,6 +27,8 @@
 #include
 #include
 #include
+#include <exceptions/cuda_exception.h>
+#include <helpers/PointersManager.h>
 
 namespace nd4j {
 namespace ops {
@@ -248,6 +250,116 @@
 BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
+
+///////////////////////////////////////////////////////////////////
+// reference: https://en.wikipedia.org/wiki/Categorical_distribution
+// method: Gumbel trick + argmax over the perturbed logits
+template <typename X, typename Z>
+__global__ static void fillMultiNomialCuda_(graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo,
+                                            void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue,
+                                            const Nd4jLong numOfSamples, const Nd4jLong numOfClassX,
+                                            const Nd4jLong dimA, const X minVal, const X maxVal) {
+
+    const X* x = reinterpret_cast<const X*>(vx);
+    Z* z = reinterpret_cast<Z*>(vz);
+
+    __shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC;
+
+    if (0 == threadIdx.x) {
+        dimC = (0 == dimA) ? 1 : 0;
+        zDimAstride = shape::stride(zShapeInfo)[dimA];
+        xDimAstride = shape::stride(xShapeInfo)[dimA];
+        zDimCstride = shape::stride(zShapeInfo)[dimC];
+        xDimCstride = shape::stride(xShapeInfo)[dimC];
+    }
+    __syncthreads();
+
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    for (Nd4jLong index = tid; index < batchValue * numOfSamples; index += gridDim.x * blockDim.x) {
+
+        Nd4jLong nBatchIndex = index / numOfSamples;
+        Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples);
+
+        const X* xTad = x + (nBatchIndex * xDimCstride);
+        Z* zTad = z + (nBatchIndex * zDimCstride);
+        Z& arg = zTad[nSampleIndexInBatch * zDimAstride];
+
+        X Max = -minVal;
+        Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
+        Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX;
+
+        for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) {
+            Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass;
+            X tValue = (xTad[nClass * xDimAstride] - nd4j::math::nd4j_log<X, X>(-nd4j::math::nd4j_log<X, X>(devRng->relativeT<X>(nIndex, minVal, maxVal))));
+            if (tValue > Max) {
+                Max = tValue;
+                arg = nClass;
+            }
+        }
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename X, typename Z>
+__host__ static void fillMultiNomialCudaLauncher(
+    const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
+    graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo,
+    void* vz, const Nd4jLong* zShapeInfo,
+    const Nd4jLong batchValue, const Nd4jLong numOfSamples,
+    const Nd4jLong numOfClassX, const Nd4jLong dimA) {
+
+    const X minVal = DataTypeUtils::min<X>();
+    const X maxVal = 1.0;
+
+    fillMultiNomialCuda_<X, Z><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(
+        devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue,
+        numOfSamples, numOfClassX, dimA, minVal, maxVal);
+}
+
+///////////////////////////////////////////////////////////////////
+void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
+
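+    // dimA is the axis that holds the classes, i.e. the axis other than the batch axis dimC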
+    Nd4jLong dimA = (0 == dimC) ? 1 : 0;
+
+    const Nd4jLong batchValue = output.sizeAt(dimC);
+    const Nd4jLong numOfClassX = input.sizeAt(dimA);
+
+    const int threadsPerBlock = MAX_NUM_THREADS / 2;
+    const int blocksPerGrid = (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock;
+
+    PointersManager manager(context, "fillMultinomial");
+    graph::RandomGenerator* devRng;
+
+    auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator));
+    if (err != 0) {
+        cuda_exception::build("fillRandomMultiNomial: Cannot allocate device memory for the random generator", err);
+    }
+    err = cudaStreamSynchronize(*context->getCudaStream());
+    if (err != 0) {
+        cuda_exception::build("fillRandomMultiNomial: Cannot synchronize stream for the random generator", err);
+    }
+    err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice, *context->getCudaStream());
+    if (err != 0) {
+        cuda_exception::build("fillRandomMultiNomial: Cannot copy the random generator to device", err);
+    }
+
+    NDArray::prepareSpecialUse({ &output }, { &input });
+    BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillMultiNomialCudaLauncher,
+        (blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, input.getSpecialBuffer(),
+        input.getSpecialShapeInfo(), output.specialBuffer(),
+        output.specialShapeInfo(), batchValue, numOfSamples,
+        numOfClassX, dimA), FLOAT_TYPES, INDEXING_TYPES);
+    NDArray::registerSpecialUse({ &output }, { &input });
+    manager.synchronize();
+
+    err = cudaFree(devRng);
+    if (err != 0) {
+        cuda_exception::build("fillRandomMultiNomial: Cannot deallocate device memory for the random generator", err);
+    }
+    rng.rewindH(output.lengthOf() * numOfClassX);
+}
+
 }
 }
 }
\ No newline at end of file
diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp
index 1072b9dab..0e320c726 100644
--- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp
@@ -1002,4 +1002,205 @@ TEST_F(RNGTests, test_uniform_119) {
     nd4j::ops::randomuniform op;
     auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
     ASSERT_EQ(Status::OK(), status);
-}
\ No newline at end of file
+}
+
+TEST_F(RNGTests, test_multinomial_1) {
+
+    NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
+    NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64);
+    NDArray output('f', { 3, 3 }, nd4j::DataType::INT64);
+    NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32);
+
+    nd4j::ops::random_multinomial op;
+    RandomGenerator rng(1234, 1234);
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
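+    // with a fixed RandomGenerator state the op is deterministic, so the drawn
+    // class indices can be compared against a precomputed expected array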
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
+    NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);
+
+    auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 });
+    auto outputZ = result->at(0);
+
+    ASSERT_EQ(Status::OK(), result->status());
+    ASSERT_TRUE(expectedZ.isSameShape(outputZ));
+    ASSERT_TRUE(expectedZ.equalsTo(outputZ));
+    delete result;
+}
+
+TEST_F(RNGTests, test_multinomial_2) {
+
+    NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32);
+    NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
+    NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64);
+    NDArray output('c', { 3, 20 }, nd4j::DataType::INT64);
+
+    nd4j::ops::random_multinomial op;
+    RandomGenerator rng(1234, 1234);
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
+    NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, nd4j::DataType::INT64);
+    NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);
+
+    rng.setStates(1234, 1234);
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false));
+    ASSERT_TRUE(expected2.isSameShape(output2));
+    ASSERT_TRUE(expected2.equalsTo(output2));
+}
+
+TEST_F(RNGTests, test_multinomial_3) {
+
+    NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
+    NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64);
+    NDArray output('c', { 4, 5 }, nd4j::DataType::INT64);
+    NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
+    RandomGenerator rng(1234, 1234);
+
+    nd4j::ops::random_multinomial op;
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, false));
+
+    rng.setStates(1234, 1234);
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+}
+
+TEST_F(RNGTests, test_multinomial_4) {
+
+    NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
+    NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64);
+    NDArray output('c', { 5, 4 }, nd4j::DataType::INT64);
+    NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
+
+    RandomGenerator rng(1234, 1234);
+    nd4j::ops::random_multinomial op;
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, false));
+
+    rng.setStates(1234, 1234);
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, false));
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+}
+
+TEST_F(RNGTests, test_multinomial_5) {
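+    // with 2 classes the multinomial reduces to a binomial distribution:
+    // for equal logits each class has p = 0.5, so the theoretical mean is 0.5
+    // and the stddev is sqrt(p * (1 - p)) = 0.5, which the asserts below check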
+    int batchValue = 1;
+    int ClassValue = 2;
+    int Samples = 1000000;
+
+    NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32);
+
+    NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::random_multinomial op;
+
+    NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);
+    RandomGenerator rng(1234, 1234);
+
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false));
+
+    auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
+    auto mean = output.meanNumber();
+    // printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
+    // theoretical values for the binomial distribution
+    ASSERT_NEAR(0.5, deviation.e<double>(0), 3e-3);
+    ASSERT_NEAR(0.5, mean.e<double>(0), 3e-3);
+
+    for (int i = 0; i < output.lengthOf(); i++) {
+        auto value = output.e<Nd4jLong>(i);
+        ASSERT_TRUE(value >= 0 && value < ClassValue);
+    }
+
+    auto resultR = op.execute({ &probs, &samples }, { }, { 1 });
+    auto outputR = resultR->at(0);
+    ASSERT_EQ(Status::OK(), resultR->status());
+
+    deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
+    mean = outputR->meanNumber();
+    // printf("Random seed - Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
+    ASSERT_NEAR(0.5, deviation.e<double>(0), 35e-3);
+    ASSERT_NEAR(0.5, mean.e<double>(0), 35e-3);
+
+    for (int i = 0; i < outputR->lengthOf(); i++) {
+        auto value = outputR->e<Nd4jLong>(i);
+        ASSERT_TRUE(value >= 0 && value < ClassValue);
+    }
+
+    delete resultR;
+}
+
+
+TEST_F(RNGTests, test_multinomial_6) {
+
+    int batchValue = 1;
+    int ClassValue = 5;
+    int Samples = 1000000;
+
+    NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32);
+
+    nd4j::ops::random_multinomial op;
+    NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE);
+
+    // without a fixed seed
+    NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
+
+    auto resultR = op.execute({ &probsR, &samples }, { }, { 0 });
+    auto outputR = resultR->at(0);
+    ASSERT_EQ(Status::OK(), resultR->status());
+
+    NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
+
+    for (int i = 0; i < outputR->lengthOf(); i++) {
+        auto value = outputR->e<Nd4jLong>(i);
+        ASSERT_TRUE(value >= 0 && value < ClassValue);
+        double* z = countsR.bufferAsT<double>();
+        z[value] += 1;
+    }
+
+    for (int i = 0; i < countsR.lengthOf(); i++) {
+        auto c = countsR.e<double>(i);
+        auto p = probExpect.e<double>(i);
+        // printf("Get freq : %f Expect freq: %f \n", c / Samples, p);
+        ASSERT_NEAR((c / Samples), p, 35e-3);
+    }
+
+    auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
+    auto mean = outputR->meanNumber();
+    // printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
+    ASSERT_NEAR(1.2175, deviation.e<double>(0), 35e-3);
+    ASSERT_NEAR(2.906, mean.e<double>(0), 35e-3);
+
+    delete resultR;
+
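+    // note: probExpect above holds softmax({1, 1.5, 2, 2.5, 3}), the normalized class
+    // probabilities implied by the logits, so the observed frequencies should match it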
+    RandomGenerator rng(1234, 1234);
+    NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
+    NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);
+
+    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
+
+    NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
+
+    for (int i = 0; i < output.lengthOf(); i++) {
+        auto value = output.e<Nd4jLong>(i);
+        ASSERT_TRUE(value >= 0 && value < ClassValue);
+        double* z = counts.bufferAsT<double>();
+        z[value] += 1;
+    }
+
+    for (int i = 0; i < counts.lengthOf(); i++) {
+        auto c = counts.e<double>(i);
+        auto p = probExpect.e<double>(i);
+        // printf("Get freq : %f Expect freq: %f \n", c / Samples, p);
+        ASSERT_NEAR((c / Samples), p, 3e-3);
+    }
+
+    deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
+    mean = output.meanNumber();
+    // printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
+    ASSERT_NEAR(1.2175, deviation.e<double>(0), 3e-3);
+    ASSERT_NEAR(2.906, mean.e<double>(0), 3e-3);
+}
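
// ---------------------------------------------------------------------------
// Illustration only (not part of the diff): a minimal standalone sketch of the
// Gumbel-max sampling loop that fillRandomMultiNomial_ implements, with
// std::mt19937 standing in for nd4j's graph::RandomGenerator. All names below
// are hypothetical and exist only for this example.

#include <cmath>
#include <cstdio>
#include <limits>
#include <random>
#include <vector>

// Draw one class index from a categorical distribution given unnormalized
// log-probabilities (logits) via the Gumbel-max trick:
//   argmax_k (logits[k] + g_k),  g_k = -log(-log(u_k)),  u_k ~ U(0, 1),
// which is the same comparison as the helper's "x - log(-log(u))".
static int sampleCategorical(const std::vector<float>& logits, std::mt19937& gen) {
    // avoid u == 0, mirroring the helper's use of DataTypeUtils::min<Tx>()
    std::uniform_real_distribution<float> uniform(std::numeric_limits<float>::min(), 1.0f);
    int arg = 0;
    float best = -std::numeric_limits<float>::infinity();
    for (std::size_t k = 0; k < logits.size(); ++k) {
        const float gumbel = -std::log(-std::log(uniform(gen)));
        const float value = logits[k] + gumbel;
        if (value > best) {
            best = value;
            arg = static_cast<int>(k);
        }
    }
    return arg;
}

int main() {
    std::mt19937 gen(1234);
    // same logits as test_multinomial_6; expected frequencies are softmax(logits)
    const std::vector<float> logits = { 1.f, 1.5f, 2.f, 2.5f, 3.f };
    std::vector<int> counts(logits.size(), 0);
    const int samples = 1000000;
    for (int i = 0; i < samples; ++i)
        ++counts[sampleCategorical(logits, gen)];
    for (std::size_t k = 0; k < counts.size(); ++k)
        std::printf("class %zu: freq %.4f\n", k, counts[k] / double(samples));
    return 0;
}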