Oleh multinomial (#163)

* libnd4j: Multinomial op #8570 first rough step of the multinomial random data generator implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op #8570 next step of the multinomial random category generator implementation on both cpu and cuda; needs corrections and code cleanup before review and testing

* libnd4j: Multinomial op #8570 code cleanup and fixed issues in data selection, moved from coords to tads

* libnd4j: Multinomial op #8570 fixed cuda build, added references to the math materials used for the implementation

* libnd4j: Multinomial op #8570 fixed several bugs, added several tests and improved the cuda version; the current implementation works, needs testing of reproducibility with the same seed

* libnd4j: Multinomial op #8570 fixes and optimizations after discussion, on both cuda and cpu

* libnd4j: Multinomial op #8570 added corrections after review, removed tads, replaced the 2D parallel loop with a 3D one

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op fixed declaration and added tests, needs discussion

* libnd4j: Multinomial op fix in test

* libnd4j: Multinomial op corrected behavior to produce reproducible results, fixed an issue in uniform value retrieval, added tests; needs cuda review and cuda testing

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op fixed indexing on uniform calculation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op some corrections in the max/min declaration

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op fixed index calculation, added rewind, corrected input declaration, added stats tests for both cuda and cpu; cuda needs testing

* libnd4j: Multinomial op fixed bugs on cuda and cpu, needs review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op corrected tests to handle different orders

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op some improvements after code review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op more corrections after review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op fixed seed usage, updated tests, fixed cuda based on comments, fixed a rewind bug, removed one behavior, minor corrections.

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op minor corrections

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op raised the bound of fluctuation for random cases

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: Multinomial op modified operation inputs and updated implementation and tests on both cpu and cuda

* libnd4j: Multinomial op corrected data types according to ops.proto

Co-authored-by: raver119 <raver119@gmail.com>
master
Oleh 2020-01-06 21:35:05 +02:00 committed by raver119
parent bb86bbc255
commit 2404be5fe0
6 changed files with 502 additions and 1 deletion

View File

@@ -0,0 +1,114 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_random_multinomial)
#include <ops/declarable/CustomOperations.h>
#include <helpers/RandomLauncher.h>
#include <ops/declarable/helpers/random.h>
namespace nd4j {
namespace ops {
///////////////////////
/**
* multinomial (categorical) random generator
* takes a 2D ndarray of logits with shape [batch_size (N), num_classes (K)],
* which represents the unnormalized log-probabilities for all classes,
* and an array with a single scalar value: the number of independent samples to draw for each of the N experiments.
* Int arguments: 0 - optional argument, index of the dimension that corresponds to batch_size
* Int arguments: 1 - optional argument, integer type to use for the output. Default int64.
*/
// reference: https://en.wikipedia.org/wiki/Categorical_distribution
// method: Gumbel-max trick - argmax of (logit + Gumbel noise) is a sample from softmax(logits)
CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
auto inputSamples = INPUT_VARIABLE(1);
REQUIRE_TRUE(!input->isEmpty(), 0, "RANDOM_MULTINOMIAL OP: Logits array must not be empty. ");
REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Number of samples has to be specified"
" as a single scalar value.");
Nd4jLong numOfSamples = static_cast<Nd4jLong>(inputSamples->e<int>(0));
// do nothing if number of samples = 0
if (0 == numOfSamples)
return Status::OK();
REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater than 0, got %i. ", numOfSamples);
const int rank = input->rankOf();
REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got rank = %i instead.", rank);
const int argSize = block.getIArguments()->size();
const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
auto dimA = (0 == dimC) ? 1 : 0;
if (1 == input->sizeAt(dimA)) {
*output = 0;
return Status::OK();
}
auto rng = block.randomGenerator();
helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, numOfSamples, dimC);
return Status::OK();
}
DECLARE_SHAPE_FN(random_multinomial) {
auto input = INPUT_VARIABLE(0);
auto inputSamples = INPUT_VARIABLE(1);
REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Number of samples has to be specified"
" as a single scalar value.");
Nd4jLong numOfSamples = static_cast<Nd4jLong>(inputSamples->e<int>(0));
REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater than 0, got %i. ", numOfSamples);
const int rank = input->rankOf();
REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got rank = %i instead.", rank);
const int argSize = block.getIArguments()->size();
const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
auto nShape = input->getShapeAsVector();
auto dimA = (0 == dimC) ? 1 : 0;
nShape[dimA] = numOfSamples;
DataType nType = (argSize > 1) ? ( INT_ARG(1) >= 0 ? static_cast<DataType>(INT_ARG(1)) : nd4j::DataType::INT64) : nd4j::DataType::INT64;
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(nType, input->ordering(), nShape));
}
DECLARE_TYPES(random_multinomial) {
getOpDescriptor()
->setAllowedInputTypes(0, { ALL_FLOATS, ALL_INTS })
->setAllowedInputTypes(1, { nd4j::DataType::INT32 })
->setAllowedOutputTypes(0, { ALL_INDICES });
}
}
}
#endif
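
For reference, here is a minimal usage sketch of the op implemented above, modeled on the tests added later in this diff; the logits values, seed and shapes are illustrative assumptions only:

    // 2 batch rows, 3 classes; values are unnormalized log-probabilities
    NDArray probs('c', { 2, 3 }, { 1.f, 2.f, 3.f, 3.f, 2.f, 1.f }, nd4j::DataType::FLOAT32);
    // draw 5 independent samples for each batch row
    NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
    NDArray output('c', { 2, 5 }, nd4j::DataType::INT64);
    nd4j::ops::random_multinomial op;
    RandomGenerator rng(1234, 1234); // a fixed seed makes the draw reproducible
    // IArgs: 0 - index of the batch_size dimension, INT64 - output integer type
    auto status = op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false);
    // every element of output is a class index in [0, 3)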

View File

@@ -49,6 +49,22 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_randomuniform)
DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
#endif
/*
* multinomial (categorical) random generator draws samples from a multinomial distribution
*
* Input array:
*    0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)]
*    1 - array with a single int value: the number of independent samples to draw for each of the N experiments
* Int arguments:
*    0 - optional argument, index of the dimension that corresponds to batch_size
*    1 - optional argument, integer type to use for the output. Default int64.
*
* Output array:
*    0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples]
*/
#if NOT_EXCLUDED(OP_random_multinomial)
DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0);
#endif
#if NOT_EXCLUDED(OP_random_normal)
DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0);

View File

@@ -24,6 +24,8 @@
//#include <graph/Context.h>
#include <ShapeUtils.h>
#include <helpers/RandomLauncher.h>
#include <execution/Threads.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j {
namespace ops {
@@ -150,6 +152,61 @@ namespace helpers {
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
}
// reference: https://en.wikipedia.org/wiki/Categorical_distribution
// method: Gumbel-max trick - argmax of (logit + Gumbel noise) is a sample from softmax(logits)
template <typename Tx, typename Tz>
void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
const Tx* x = input.bufferAsT<Tx>();
Tz* z = output.bufferAsT<Tz>();
Tx minVal = DataTypeUtils::min<Tx>();
Tx maxVal = 1.0;
auto dimA = (0 == dimC) ? 1 : 0;
const Nd4jLong batchValue = output.sizeAt(dimC);
const Nd4jLong numOfClassX = input.sizeAt(dimA);
const Nd4jLong zDimAstride = output.stridesOf()[dimA];
const Nd4jLong xDimAstride = input.stridesOf()[dimA];
const Nd4jLong zDimCstride = output.stridesOf()[dimC];
const Nd4jLong xDimCstride = input.stridesOf()[dimC];
auto func = PRAGMA_THREADS_FOR_2D{
for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) {
for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) {
const Tx* xTad = x + (nBatchIndex * xDimCstride);
Tz* zTad = z + (nBatchIndex * zDimCstride);
Tz& arg = zTad[nSampleIndexInBatch * zDimAstride];
Tx Max = -minVal;
auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
auto nClassesPerSample = nSampleIndexInBatch * numOfClassX;
for (auto nClass = 0; nClass < numOfClassX; nClass += 1) {
auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass;
auto uniformLog = nd4j::math::nd4j_log<Tx, Tx>(-nd4j::math::nd4j_log<Tx, Tx>(rng.relativeT<Tx>(nIndex, minVal, maxVal)));
Tx tValue = (xTad[nClass * xDimAstride] - uniformLog);
if (tValue > Max) {
Max = tValue;
arg = nClass;
}
}
}
}
};
samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
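// advance the host RNG state past everything consumed above: one uniform draw per (batch, sample, class) triple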
rng.rewindH(output.lengthOf()*numOfClassX);
return;
}
void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillRandomMultiNomial_, (context, rng, input, output, numOfSamples, dimC), FLOAT_TYPES, INDEXING_TYPES);
}
}
}
}
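
The Gumbel-max trick used above can be checked in isolation. A minimal self-contained sketch, assuming only the C++ standard library (gumbelMaxSample is a hypothetical helper for illustration, not part of this change):

    #include <cmath>
    #include <limits>
    #include <random>
    #include <vector>

    // Draws one categorical sample: argmax_k(logit_k + g_k), with Gumbel noise
    // g_k = -log(-log(u_k)), u_k ~ Uniform(0, 1), is distributed as
    // Categorical(softmax(logits)) - no explicit softmax or normalization needed.
    int gumbelMaxSample(const std::vector<double>& logits, std::mt19937_64& gen) {
        // mirror the helper's (minVal, 1.0) range so log(u) stays finite
        std::uniform_real_distribution<double> uniform(std::numeric_limits<double>::min(), 1.0);
        int arg = 0;
        double best = -std::numeric_limits<double>::infinity();
        for (std::size_t k = 0; k < logits.size(); ++k) {
            double gumbel = -std::log(-std::log(uniform(gen)));
            double value = logits[k] + gumbel;
            if (value > best) {
                best = value;
                arg = static_cast<int>(k);
            }
        }
        return arg;
    }

Averaging many draws of gumbelMaxSample({1., 1.5, 2., 2.5, 3.}, gen) reproduces the softmax frequencies asserted in test_multinomial_6 below.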

View File

@@ -27,6 +27,8 @@
#include <ShapeUtils.h>
#include <NDArrayFactory.h>
#include <cuda_exception.h>
#include <helpers/ConstantTadHelper.h>
#include <PointersManager.h>
namespace nd4j {
namespace ops {
@@ -248,6 +250,116 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
///////////////////////////////////////////////////////////////////
// reference: https://en.wikipedia.org/wiki/Categorical_distribution
// method: Gumbel-max trick - argmax of (logit + Gumbel noise) is a sample from softmax(logits)
template<typename X, typename Z>
__global__ static void fillMultiNomialCuda_(graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue,
const Nd4jLong numOfSamples, const Nd4jLong numOfClassX,
const Nd4jLong dimA, const X minVal, const X maxVal) {
const X* x = reinterpret_cast<const X*>(vx);
Z* z = reinterpret_cast<Z*>(vz);
__shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC;
if (0 == threadIdx.x) {
dimC = (0 == dimA) ? 1 : 0;
zDimAstride = shape::stride(zShapeInfo)[dimA];
xDimAstride = shape::stride(xShapeInfo)[dimA];
zDimCstride = shape::stride(zShapeInfo)[dimC];
xDimCstride = shape::stride(xShapeInfo)[dimC];
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong index = tid; index < batchValue*numOfSamples; index += gridDim.x * blockDim.x) {
Nd4jLong nBatchIndex = index / numOfSamples;
Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples);
const X* xTad = x + (nBatchIndex * xDimCstride);
Z* zTad = z + (nBatchIndex * zDimCstride);
Z& arg = zTad[nSampleIndexInBatch * zDimAstride];
X Max = -minVal;
Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX;
for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) {
Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass;
X tValue = (xTad[nClass * xDimAstride] - nd4j::math::nd4j_log<X, X>(-nd4j::math::nd4j_log<X, X>(devRng->relativeT<X>(nIndex, minVal, maxVal))));
if (tValue > Max) {
Max = tValue;
arg = nClass;
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
__host__ static void fillMultiNomialCudaLauncher(
const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo,
void* vz, const Nd4jLong* zShapeInfo,
const Nd4jLong batchValue, const Nd4jLong numOfSamples,
const Nd4jLong numOfClassX, const Nd4jLong dimA){
const X minVal = DataTypeUtils::min<X>();
const X maxVal = 1.0;
fillMultiNomialCuda_<X, Z><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(
devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue,
numOfSamples, numOfClassX, dimA, minVal, maxVal);
}
///////////////////////////////////////////////////////////////////
void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
Nd4jLong dimA = (0 == dimC) ? 1 : 0;
const Nd4jLong batchValue = output.sizeAt(dimC);
const Nd4jLong numOfClassX = input.sizeAt(dimA);
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock;
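// one GPU thread per output element, i.e. per (batch, sample) pair; each thread scans all numOfClassX classes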
PointersManager manager(context, "fillMultinomial");
graph::RandomGenerator *devRng;
auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator));
if (err != 0) {
cuda_exception::build("fillRandomMultiNomial: Cannot allocate device memory for random generator due error", err);
}
err = cudaStreamSynchronize(*context->getCudaStream());
if (err != 0) {
cuda_exception::build("fillRandomMultiNomial: Cannot synchronize stream for random generator due error", err);
}
err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice, *context->getCudaStream());
if (err != 0) {
cuda_exception::build("fillRandomMultiNomial: Cannot copy random generator to device", err);
}
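// the kernel dereferences the generator through devRng, so its state has to live in device memory for the launch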
NDArray::prepareSpecialUse({ &output }, { &input });
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillMultiNomialCudaLauncher,
(blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, input.getSpecialBuffer(),
input.getSpecialShapeInfo(), output.specialBuffer(),
output.specialShapeInfo(), batchValue, numOfSamples,
numOfClassX, dimA), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({ &output }, { &input });
manager.synchronize();
err = cudaFree(devRng);
if (err != 0) {
cuda_exception::build("fillRandomMultiNomial: Cannot deallocate device memory for random generator", err);
}
rng.rewindH(output.lengthOf() * numOfClassX);
}
}
}
}

View File

@@ -34,6 +34,7 @@ namespace helpers {
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output);
void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC);
}
}
}

View File

@@ -1003,3 +1003,204 @@ TEST_F(RNGTests, test_uniform_119) {
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
ASSERT_EQ(Status::OK(), status);
}
TEST_F(RNGTests, test_multinomial_1) {
NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64);
NDArray output('f', { 3, 3 }, nd4j::DataType::INT64);
NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32);
nd4j::ops::random_multinomial op;
RandomGenerator rng(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, false) );
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);
auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 });
auto outputZ = result->at(0);
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(expectedZ.isSameShape(outputZ));
ASSERT_TRUE(expectedZ.equalsTo(outputZ));
delete result;
}
TEST_F(RNGTests, test_multinomial_2) {
NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32);
NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64);
NDArray output('c', { 3, 20 }, nd4j::DataType::INT64);
nd4j::ops::random_multinomial op;
RandomGenerator rng(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, nd4j::DataType::INT64);
NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);
rng.setStates(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false));
ASSERT_TRUE(expected2.isSameShape(output2));
ASSERT_TRUE(expected2.equalsTo(output2));
}
TEST_F(RNGTests, test_multinomial_3) {
NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64);
NDArray output('c', { 4, 5 }, nd4j::DataType::INT64);
NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
RandomGenerator rng(1234, 1234);
nd4j::ops::random_multinomial op;
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, false));
rng.setStates(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
TEST_F(RNGTests, test_multinomial_4) {
NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64);
NDArray output('c', { 5, 4 }, nd4j::DataType::INT64);
NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
RandomGenerator rng(1234, 1234);
nd4j::ops::random_multinomial op;
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, false));
rng.setStates(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, false));
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
TEST_F(RNGTests, test_multinomial_5) {
// multinomial behaves as binomial when only 2 classes are used
int batchValue = 1;
int ClassValue = 2;
int Samples = 1000000;
NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32);
NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32);
nd4j::ops::random_multinomial op;
NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);
RandomGenerator rng(1234, 1234);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false));
auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
auto mean = output.meanNumber();
// printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
// theoretical values for binomial: each sample is Bernoulli(p = 0.5), so mean = 0.5 and stddev = sqrt(p * (1 - p)) = 0.5
ASSERT_NEAR(0.5, deviation.e<double>(0), 3e-3);
ASSERT_NEAR(0.5, mean.e<double>(0), 3e-3);
for (int i = 0; i < output.lengthOf(); i++) {
auto value = output.e<Nd4jLong>(i);
ASSERT_TRUE(value >= 0 && value < ClassValue);
}
auto resultR = op.execute({ &probs, &samples }, { }, { 1 });
auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status());
deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
mean = outputR->meanNumber();
// printf("Random seed - Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
ASSERT_NEAR(0.5, deviation.e<double>(0), 35e-3);
ASSERT_NEAR(0.5, mean.e<double>(0), 35e-3);
for (int i = 0; i < outputR->lengthOf(); i++) {
auto value = outputR->e<Nd4jLong>(i);
ASSERT_TRUE(value >= 0 && value < ClassValue);
}
delete resultR;
}
TEST_F(RNGTests, test_multinomial_6) {
int batchValue = 1;
int ClassValue = 5;
int Samples = 1000000;
NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32);
nd4j::ops::random_multinomial op;
NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE);
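// probExpect holds softmax(1., 1.5, 2., 2.5, 3.): exp(logit_k) / sum_j exp(logit_j)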
// without seed
NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
auto resultR = op.execute({ &probsR, &samples }, { }, { 0 });
auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status());
NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
for (int i = 0; i < outputR->lengthOf(); i++) {
auto value = outputR->e<Nd4jLong>(i);
ASSERT_TRUE(value >= 0 && value < ClassValue);
double* z = countsR.bufferAsT<double>();
z[value] += 1;
}
for (int i = 0; i < countsR.lengthOf(); i++) {
auto c = countsR.e<double>(i);
auto p = probExpect.e<double>(i);
// printf("Get freq : %f Expect freq: %f \n", c / Samples, p);
ASSERT_NEAR((c / Samples), p, 35e-3);
}
auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
auto mean = outputR->meanNumber();
// printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
ASSERT_NEAR(1.2175, deviation.e<double>(0), 35e-3);
ASSERT_NEAR(2.906, mean.e<double>(0), 35e-3);
delete resultR;
RandomGenerator rng(1234, 1234);
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));
NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
for (int i = 0; i < output.lengthOf(); i++) {
auto value = output.e<Nd4jLong>(i);
ASSERT_TRUE(value >= 0 && value < ClassValue);
double* z = counts.bufferAsT<double>();
z[value] += 1;
}
for (int i = 0; i < counts.lengthOf(); i++) {
auto c = counts.e<double>(i);
auto p = probExpect.e<double>(i);
// printf("Get freq : %f Expect freq: %f \n", c / Samples, p);
ASSERT_NEAR((c / Samples), p, 3e-3);
}
deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
mean = output.meanNumber();
// printf("Var: %f Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
ASSERT_NEAR(1.2175, deviation.e<double>(0), 3e-3);
ASSERT_NEAR(2.906, mean.e<double>(0), 3e-3);
}